Skip to content

Commit

Permalink
Merge pull request #364 from Pylons/feature/flush-from-app-thread
Browse files Browse the repository at this point in the history
Flush data from the application thread
  • Loading branch information
mmerickel authored Jan 18, 2022
2 parents 7acbb87 + 1bbdf6c commit 759f4d1
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 19 deletions.
46 changes: 38 additions & 8 deletions src/waitress/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def check_client_disconnected(self):
may occasionally check if the client has disconnected and interrupt
execution.
"""

return not self.connected

def writable(self):
Expand Down Expand Up @@ -116,23 +117,30 @@ def handle_write(self):
# right now.
flush = None

self._flush_exception(flush)

if self.close_when_flushed and not self.total_outbufs_len:
self.close_when_flushed = False
self.will_close = True

if self.will_close:
self.handle_close()

def _flush_exception(self, flush):
if flush:
try:
flush()
return (flush(), False)
except OSError:
if self.adj.log_socket_errors:
self.logger.exception("Socket error")
self.will_close = True

return (False, True)
except Exception: # pragma: nocover
self.logger.exception("Unexpected exception when flushing")
self.will_close = True

if self.close_when_flushed and not self.total_outbufs_len:
self.close_when_flushed = False
self.will_close = True

if self.will_close:
self.handle_close()
return (False, True)

def readable(self):
# We might want to read more requests. We can only do this if:
Expand Down Expand Up @@ -190,6 +198,7 @@ def received(self, data):
Receives input asynchronously and assigns one or more requests to the
channel.
"""

if not data:
return False

Expand All @@ -201,6 +210,7 @@ def received(self, data):

# if there are requests queued, we can not send the continue
# header yet since the responses need to be kept in order

if (
self.request.expect_continue
and self.request.headers_finished
Expand All @@ -215,6 +225,7 @@ def received(self, data):

if not self.request.empty:
self.requests.append(self.request)

if len(self.requests) == 1:
# self.requests was empty before so the main thread
# is in charge of starting the task. Otherwise,
Expand Down Expand Up @@ -363,7 +374,14 @@ def write_soon(self, data):
self.total_outbufs_len += num_bytes

if self.total_outbufs_len >= self.adj.send_bytes:
self.server.pull_trigger()
(flushed, exception) = self._flush_exception(self._flush_some)

if (
exception
or not flushed
or self.total_outbufs_len >= self.adj.send_bytes
):
self.server.pull_trigger()

return num_bytes

Expand All @@ -374,6 +392,17 @@ def _flush_outbufs_below_high_watermark(self):

if self.total_outbufs_len > self.adj.outbuf_high_watermark:
with self.outbuf_lock:
(_, exception) = self._flush_exception(self._flush_some)

if exception:
# An exception happened while flushing, wake up the main
# thread, then wait for it to decide what to do next
# (probably close the socket, and then just return)
self.server.pull_trigger()
self.outbuf_lock.wait()

return

while (
self.connected
and self.total_outbufs_len > self.adj.outbuf_high_watermark
Expand Down Expand Up @@ -460,6 +489,7 @@ def service(self):
# Add new task to process the next request
with self.requests_lock:
self.requests.pop(0)

if self.connected and self.requests:
self.server.add_task(self)
elif (
Expand Down
95 changes: 84 additions & 11 deletions tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,13 @@ def test_write_soon_empty_byte(self):

def test_write_soon_nonempty_byte(self):
inst, sock, map = self._makeOneWithMap()

# _flush_some will no longer flush
def send(_):
return 0

sock.send = send

wrote = inst.write_soon(b"a")
self.assertEqual(wrote, 1)
self.assertEqual(len(inst.outbufs[0]), 1)
Expand All @@ -224,14 +231,19 @@ def test_write_soon_filewrapper(self):
wrapper = ReadOnlyFileBasedBuffer(f, 8192)
wrapper.prepare()
inst, sock, map = self._makeOneWithMap()

# _flush_some will no longer flush
def send(_):
return 0

sock.send = send

outbufs = inst.outbufs
orig_outbuf = outbufs[0]
wrote = inst.write_soon(wrapper)
self.assertEqual(wrote, 3)
self.assertEqual(len(outbufs), 3)
self.assertEqual(outbufs[0], orig_outbuf)
self.assertEqual(outbufs[1], wrapper)
self.assertEqual(outbufs[2].__class__.__name__, "OverflowableBuffer")
self.assertEqual(len(outbufs), 2)
self.assertEqual(outbufs[0], wrapper)
self.assertEqual(outbufs[1].__class__.__name__, "OverflowableBuffer")

def test_write_soon_disconnected(self):
from waitress.channel import ClientDisconnected
Expand All @@ -253,16 +265,29 @@ def dummy_flush():

def test_write_soon_rotates_outbuf_on_overflow(self):
inst, sock, map = self._makeOneWithMap()

# _flush_some will no longer flush
def send(_):
return 0

sock.send = send

inst.adj.outbuf_high_watermark = 3
inst.current_outbuf_count = 4
wrote = inst.write_soon(b"xyz")
self.assertEqual(wrote, 3)
self.assertEqual(len(inst.outbufs), 2)
self.assertEqual(inst.outbufs[0].get(), b"")
self.assertEqual(inst.outbufs[1].get(), b"xyz")
self.assertEqual(len(inst.outbufs), 1)
self.assertEqual(inst.outbufs[0].get(), b"xyz")

def test_write_soon_waits_on_backpressure(self):
inst, sock, map = self._makeOneWithMap()

# _flush_some will no longer flush
def send(_):
return 0

sock.send = send

inst.adj.outbuf_high_watermark = 3
inst.total_outbufs_len = 4
inst.current_outbuf_count = 4
Expand All @@ -275,11 +300,59 @@ def wait(self):
inst.outbuf_lock = Lock()
wrote = inst.write_soon(b"xyz")
self.assertEqual(wrote, 3)
self.assertEqual(len(inst.outbufs), 2)
self.assertEqual(inst.outbufs[0].get(), b"")
self.assertEqual(inst.outbufs[1].get(), b"xyz")
self.assertEqual(len(inst.outbufs), 1)
self.assertEqual(inst.outbufs[0].get(), b"xyz")
self.assertTrue(inst.outbuf_lock.waited)

def test_write_soon_attempts_flush_high_water_and_exception(self):
from waitress.channel import ClientDisconnected

inst, sock, map = self._makeOneWithMap()

# _flush_some will no longer flush, it will raise Exception, which
# disconnects the remote end
def send(_):
inst.connected = False
raise Exception()

sock.send = send

inst.adj.outbuf_high_watermark = 3
inst.total_outbufs_len = 4
inst.current_outbuf_count = 4

inst.outbufs[0].append(b"test")

class Lock(DummyLock):
def wait(self):
inst.total_outbufs_len = 0
super().wait()

inst.outbuf_lock = Lock()
self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"xyz"))

# Validate we woke up the main thread to deal with the exception of
# trying to send
self.assertTrue(inst.outbuf_lock.waited)
self.assertTrue(inst.server.trigger_pulled)

def test_write_soon_flush_and_exception(self):
inst, sock, map = self._makeOneWithMap()

# _flush_some will no longer flush, it will raise Exception, which
# disconnects the remote end
def send(_):
inst.connected = False
raise Exception()

sock.send = send

wrote = inst.write_soon(b"xyz")
self.assertEqual(wrote, 3)
# Validate we woke up the main thread to deal with the exception of
# trying to send
self.assertTrue(inst.server.trigger_pulled)

def test_handle_write_notify_after_flush(self):
inst, sock, map = self._makeOneWithMap()
inst.requests = [True]
Expand Down

0 comments on commit 759f4d1

Please sign in to comment.