diff --git a/aiohttp/file_sender.py b/aiohttp/file_sender.py index cd534fb0820..197386f0562 100644 --- a/aiohttp/file_sender.py +++ b/aiohttp/file_sender.py @@ -21,6 +21,8 @@ def _sendfile_cb(self, fut, out_fd, in_fd, offset, count, loop, registered): if registered: loop.remove_writer(out_fd) + if fut.cancelled(): + return try: n = os.sendfile(out_fd, in_fd, offset, count) if n == 0: # EOF reached @@ -39,34 +41,51 @@ def _sendfile_cb(self, fut, out_fd, in_fd, offset, @asyncio.coroutine def _sendfile_system(self, request, resp, fobj, count): - """ - Write `count` bytes of `fobj` to `resp` using - the ``sendfile`` system call. - - `request` should be a :obj:`aiohttp.web.Request` instance. - - `resp` should be a :obj:`aiohttp.web.StreamResponse` instance. + # Write count bytes of fobj to resp using + # the os.sendfile system call. + # + # request should be a aiohttp.web.Request instance. + # + # resp should be a aiohttp.web.StreamResponse instance. + # + # fobj should be an open file object. + # + # count should be an integer > 0. - `fobj` should be an open file object. - - `count` should be an integer > 0. - """ transport = request.transport if transport.get_extra_info("sslcontext"): yield from self._sendfile_fallback(request, resp, fobj, count) return - yield from resp.drain() + def _send_headers(resp_impl): + # Durty hack required for + # https://github.com/KeepSafe/aiohttp/issues/1093 + # don't send headers in sendfile mode + pass + + resp._send_headers = _send_headers + yield from resp.prepare(request) loop = request.app.loop # See https://github.com/KeepSafe/aiohttp/issues/958 for details + + # send headers + headers = ['HTTP/{0.major}.{0.minor} 200 OK\r\n'.format( + request.version)] + for hdr, val in resp.headers.items(): + headers.append('{}: {}\r\n'.format(hdr, val)) + headers.append('\r\n') + out_socket = transport.get_extra_info("socket").dup() + out_socket.setblocking(False) out_fd = out_socket.fileno() in_fd = fobj.fileno() - fut = create_future(loop) try: + yield from loop.sock_sendall(out_socket, + ''.join(headers).encode('utf-8')) + fut = create_future(loop) self._sendfile_cb(fut, out_fd, in_fd, 0, count, loop, False) yield from fut @@ -75,15 +94,16 @@ def _sendfile_system(self, request, resp, fobj, count): @asyncio.coroutine def _sendfile_fallback(self, request, resp, fobj, count): - """ - Mimic the :meth:`_sendfile_system` method, but without using the - ``sendfile`` system call. This should be used on systems that don't - support the ``sendfile`` system call. - - To avoid blocking the event loop & to keep memory usage low, `fobj` is - transferred in chunks controlled by the `chunk_size` argument to - :class:`StaticRoute`. - """ + # Mimic the _sendfile_system() method, but without using the + # os.sendfile() system call. This should be used on systems + # that don't support the os.sendfile(). + + # To avoid blocking the event loop & to keep memory usage low, + # fobj is transferred in chunks controlled by the + # constructor's chunk_size argument. + + yield from resp.prepare(request) + chunk_size = self._chunk_size chunk = fobj.read(chunk_size) @@ -102,6 +122,7 @@ def _sendfile_fallback(self, request, resp, fobj, count): @asyncio.coroutine def send(self, request, filepath): + """Send filepath to client using request.""" st = filepath.stat() modsince = request.if_modified_since @@ -124,8 +145,6 @@ def send(self, request, filepath): resp.content_length = file_size resp.set_tcp_cork(True) try: - yield from resp.prepare(request) - with filepath.open('rb') as f: yield from self._sendfile(request, resp, f, file_size) diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index 055b22ab4b0..b4062397c86 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -327,24 +327,24 @@ def new_func(self): @contextlib.contextmanager -def loop_context(): +def loop_context(loop_factory=asyncio.new_event_loop): """A contextmanager that creates an event_loop, for test purposes. Handles the creation and cleanup of a test loop. """ - loop = setup_test_loop() + loop = setup_test_loop(loop_factory) yield loop teardown_test_loop(loop) -def setup_test_loop(): +def setup_test_loop(loop_factory=asyncio.new_event_loop): """Create and return an asyncio.BaseEventLoop instance. The caller should also call teardown_test_loop, once they are done with the loop. """ - loop = asyncio.new_event_loop() + loop = loop_factory() asyncio.set_event_loop(None) return loop diff --git a/aiohttp/web_reqrep.py b/aiohttp/web_reqrep.py index f98bfe1c9b6..eaee8b9baf6 100644 --- a/aiohttp/web_reqrep.py +++ b/aiohttp/web_reqrep.py @@ -724,9 +724,15 @@ def _start(self, request): resp_impl.transport.set_tcp_nodelay(self._tcp_nodelay) resp_impl.transport.set_tcp_cork(self._tcp_cork) - resp_impl.send_headers() + self._send_headers(resp_impl) return resp_impl + def _send_headers(self, resp_impl): + # Durty hack required for + # https://github.com/KeepSafe/aiohttp/issues/1093 + # File sender may override it + resp_impl.send_headers() + def write(self, data): assert isinstance(data, (bytes, bytearray, memoryview)), \ "data argument must be byte-ish (%r)" % type(data) diff --git a/tests/test_web_sendfile.py b/tests/test_web_sendfile.py index 86d8f439ff2..ccc24d8099f 100644 --- a/tests/test_web_sendfile.py +++ b/tests/test_web_sendfile.py @@ -60,3 +60,18 @@ def test_static_handle_exception(loop): assert exc is fut.exception() assert not fake_loop.add_writer.called assert not fake_loop.remove_writer.called + + +def test__sendfile_cb_return_on_cancelling(loop): + fake_loop = mock.Mock() + with mock.patch('aiohttp.file_sender.os') as m_os: + out_fd = 30 + in_fd = 31 + fut = helpers.create_future(loop) + fut.cancel() + file_sender = FileSender() + file_sender._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False) + assert fut.done() + assert not fake_loop.add_writer.called + assert not fake_loop.remove_writer.called + assert not m_os.sendfile.called diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index 5289b49aeb1..bf720c6f969 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -9,7 +9,7 @@ import aiohttp from aiohttp import log, request, web from aiohttp.file_sender import FileSender -from aiohttp.test_utils import unused_port +from aiohttp.test_utils import loop_context, unused_port try: import ssl @@ -17,6 +17,23 @@ ssl = False +try: + import uvloop +except: + uvloop = None + + +LOOP_FACTORIES = [asyncio.new_event_loop] +if uvloop: + LOOP_FACTORIES.append(uvloop.new_event_loop) + + +@pytest.yield_fixture(params=LOOP_FACTORIES) +def loop(request): + with loop_context(request.param) as loop: + yield loop + + @pytest.fixture(params=['sendfile', 'fallback'], ids=['sendfile', 'fallback']) def sender(request): def maker(*args, **kwargs):