From 1bbdf6cd985028db64b0287c4b6748e94e54cc40 Mon Sep 17 00:00:00 2001 From: Bert JW Regeer Date: Mon, 17 Jan 2022 17:53:50 -0700 Subject: [PATCH] Flush data from the application thread To speed up how soon the connected client sees data we now attempt to flush data from the application thread when we get new data to write to the socket. This saves us the need to wake up the main thread, which would then return from select(), process all sockets, look for the ones that are writable, and then call select() again. When that select() would return it would finally start writing data to the remote socket. There was also no gaurantee that the main thread would get the lock for the output buffers, and it would not be able to write any data at all thereby looping on select() until the application thread had written enough data to the buffers for it to hit the high water mark, or the response was fully buffered, potentially overflowing from memory buffers to disk. If the socket is not ready for data, due it being non-blocking, we will not flush any data at all, and will go notify/wake up the main thread to start sending the data when the socket is ready. Delivery of first byte from the WSGI application to the remote client is now faster, and it may alleviate buffer pressure. Especially if the remote client is connected over localhost, as is the case with a load balancer in front of waitress. --- src/waitress/channel.py | 46 ++++++++++++++++---- tests/test_channel.py | 95 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 122 insertions(+), 19 deletions(-) diff --git a/src/waitress/channel.py b/src/waitress/channel.py index 7d1f385b..948b4986 100644 --- a/src/waitress/channel.py +++ b/src/waitress/channel.py @@ -78,6 +78,7 @@ def check_client_disconnected(self): may occasionally check if the client has disconnected and interrupt execution. """ + return not self.connected def writable(self): @@ -116,23 +117,30 @@ def handle_write(self): # right now. flush = None + self._flush_exception(flush) + + if self.close_when_flushed and not self.total_outbufs_len: + self.close_when_flushed = False + self.will_close = True + + if self.will_close: + self.handle_close() + + def _flush_exception(self, flush): if flush: try: - flush() + return (flush(), False) except OSError: if self.adj.log_socket_errors: self.logger.exception("Socket error") self.will_close = True + + return (False, True) except Exception: # pragma: nocover self.logger.exception("Unexpected exception when flushing") self.will_close = True - if self.close_when_flushed and not self.total_outbufs_len: - self.close_when_flushed = False - self.will_close = True - - if self.will_close: - self.handle_close() + return (False, True) def readable(self): # We might want to read more requests. We can only do this if: @@ -190,6 +198,7 @@ def received(self, data): Receives input asynchronously and assigns one or more requests to the channel. """ + if not data: return False @@ -201,6 +210,7 @@ def received(self, data): # if there are requests queued, we can not send the continue # header yet since the responses need to be kept in order + if ( self.request.expect_continue and self.request.headers_finished @@ -215,6 +225,7 @@ def received(self, data): if not self.request.empty: self.requests.append(self.request) + if len(self.requests) == 1: # self.requests was empty before so the main thread # is in charge of starting the task. Otherwise, @@ -363,7 +374,14 @@ def write_soon(self, data): self.total_outbufs_len += num_bytes if self.total_outbufs_len >= self.adj.send_bytes: - self.server.pull_trigger() + (flushed, exception) = self._flush_exception(self._flush_some) + + if ( + exception + or not flushed + or self.total_outbufs_len >= self.adj.send_bytes + ): + self.server.pull_trigger() return num_bytes @@ -374,6 +392,17 @@ def _flush_outbufs_below_high_watermark(self): if self.total_outbufs_len > self.adj.outbuf_high_watermark: with self.outbuf_lock: + (_, exception) = self._flush_exception(self._flush_some) + + if exception: + # An exception happened while flushing, wake up the main + # thread, then wait for it to decide what to do next + # (probably close the socket, and then just return) + self.server.pull_trigger() + self.outbuf_lock.wait() + + return + while ( self.connected and self.total_outbufs_len > self.adj.outbuf_high_watermark @@ -460,6 +489,7 @@ def service(self): # Add new task to process the next request with self.requests_lock: self.requests.pop(0) + if self.connected and self.requests: self.server.add_task(self) elif ( diff --git a/tests/test_channel.py b/tests/test_channel.py index d86dbbe7..b1c317d4 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -213,6 +213,13 @@ def test_write_soon_empty_byte(self): def test_write_soon_nonempty_byte(self): inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush + def send(_): + return 0 + + sock.send = send + wrote = inst.write_soon(b"a") self.assertEqual(wrote, 1) self.assertEqual(len(inst.outbufs[0]), 1) @@ -224,14 +231,19 @@ def test_write_soon_filewrapper(self): wrapper = ReadOnlyFileBasedBuffer(f, 8192) wrapper.prepare() inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush + def send(_): + return 0 + + sock.send = send + outbufs = inst.outbufs - orig_outbuf = outbufs[0] wrote = inst.write_soon(wrapper) self.assertEqual(wrote, 3) - self.assertEqual(len(outbufs), 3) - self.assertEqual(outbufs[0], orig_outbuf) - self.assertEqual(outbufs[1], wrapper) - self.assertEqual(outbufs[2].__class__.__name__, "OverflowableBuffer") + self.assertEqual(len(outbufs), 2) + self.assertEqual(outbufs[0], wrapper) + self.assertEqual(outbufs[1].__class__.__name__, "OverflowableBuffer") def test_write_soon_disconnected(self): from waitress.channel import ClientDisconnected @@ -253,16 +265,29 @@ def dummy_flush(): def test_write_soon_rotates_outbuf_on_overflow(self): inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush + def send(_): + return 0 + + sock.send = send + inst.adj.outbuf_high_watermark = 3 inst.current_outbuf_count = 4 wrote = inst.write_soon(b"xyz") self.assertEqual(wrote, 3) - self.assertEqual(len(inst.outbufs), 2) - self.assertEqual(inst.outbufs[0].get(), b"") - self.assertEqual(inst.outbufs[1].get(), b"xyz") + self.assertEqual(len(inst.outbufs), 1) + self.assertEqual(inst.outbufs[0].get(), b"xyz") def test_write_soon_waits_on_backpressure(self): inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush + def send(_): + return 0 + + sock.send = send + inst.adj.outbuf_high_watermark = 3 inst.total_outbufs_len = 4 inst.current_outbuf_count = 4 @@ -275,11 +300,59 @@ def wait(self): inst.outbuf_lock = Lock() wrote = inst.write_soon(b"xyz") self.assertEqual(wrote, 3) - self.assertEqual(len(inst.outbufs), 2) - self.assertEqual(inst.outbufs[0].get(), b"") - self.assertEqual(inst.outbufs[1].get(), b"xyz") + self.assertEqual(len(inst.outbufs), 1) + self.assertEqual(inst.outbufs[0].get(), b"xyz") self.assertTrue(inst.outbuf_lock.waited) + def test_write_soon_attempts_flush_high_water_and_exception(self): + from waitress.channel import ClientDisconnected + + inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush, it will raise Exception, which + # disconnects the remote end + def send(_): + inst.connected = False + raise Exception() + + sock.send = send + + inst.adj.outbuf_high_watermark = 3 + inst.total_outbufs_len = 4 + inst.current_outbuf_count = 4 + + inst.outbufs[0].append(b"test") + + class Lock(DummyLock): + def wait(self): + inst.total_outbufs_len = 0 + super().wait() + + inst.outbuf_lock = Lock() + self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"xyz")) + + # Validate we woke up the main thread to deal with the exception of + # trying to send + self.assertTrue(inst.outbuf_lock.waited) + self.assertTrue(inst.server.trigger_pulled) + + def test_write_soon_flush_and_exception(self): + inst, sock, map = self._makeOneWithMap() + + # _flush_some will no longer flush, it will raise Exception, which + # disconnects the remote end + def send(_): + inst.connected = False + raise Exception() + + sock.send = send + + wrote = inst.write_soon(b"xyz") + self.assertEqual(wrote, 3) + # Validate we woke up the main thread to deal with the exception of + # trying to send + self.assertTrue(inst.server.trigger_pulled) + def test_handle_write_notify_after_flush(self): inst, sock, map = self._makeOneWithMap() inst.requests = [True]