From 678993a4bd071200fbcc14fa121c6673221aa540 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 4 Dec 2024 14:13:05 -0600 Subject: [PATCH] Fix race in `FileResponse` if file is replaced during `prepare` (#10101) --- CHANGES/10101.bugfix.rst | 1 + aiohttp/web_fileresponse.py | 79 ++++++++++++++++++++++++--------- tests/test_web_urldispatcher.py | 7 +-- 3 files changed, 63 insertions(+), 24 deletions(-) create mode 100644 CHANGES/10101.bugfix.rst diff --git a/CHANGES/10101.bugfix.rst b/CHANGES/10101.bugfix.rst new file mode 100644 index 00000000000..e06195ac028 --- /dev/null +++ b/CHANGES/10101.bugfix.rst @@ -0,0 +1 @@ +Fixed race condition in :class:`aiohttp.web.FileResponse` that could have resulted in an incorrect response if the file was replaced on the file system during ``prepare`` -- by :user:`bdraco`. diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index dacbb2b5892..eafc3b051cd 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -1,4 +1,5 @@ import asyncio +import io import os import pathlib import sys @@ -14,6 +15,7 @@ Callable, Final, Optional, + Set, Tuple, cast, ) @@ -70,6 +72,9 @@ CONTENT_TYPES.add_type(content_type, extension) +_CLOSE_FUTURES: Set[asyncio.Future[None]] = set() + + class FileResponse(StreamResponse): """A response object can be used to send files.""" @@ -158,10 +163,10 @@ async def _precondition_failed( self.content_length = 0 return await super().prepare(request) - def _get_file_path_stat_encoding( + def _open_file_path_stat_encoding( self, accept_encoding: str - ) -> Tuple[pathlib.Path, os.stat_result, Optional[str]]: - """Return the file path, stat result, and encoding. + ) -> Tuple[Optional[io.BufferedReader], os.stat_result, Optional[str]]: + """Return the io object, stat result, and encoding. If an uncompressed file is returned, the encoding is set to :py:data:`None`. @@ -179,10 +184,27 @@ def _get_file_path_stat_encoding( # Do not follow symlinks and ignore any non-regular files. st = compressed_path.lstat() if S_ISREG(st.st_mode): - return compressed_path, st, file_encoding + fobj = compressed_path.open("rb") + with suppress(OSError): + # fstat() may not be available on all platforms + # Once we open the file, we want the fstat() to ensure + # the file has not changed between the first stat() + # and the open(). + st = os.stat(fobj.fileno()) + return fobj, st, file_encoding # Fallback to the uncompressed file - return file_path, file_path.stat(), None + st = file_path.stat() + if not S_ISREG(st.st_mode): + return None, st, None + fobj = file_path.open("rb") + with suppress(OSError): + # fstat() may not be available on all platforms + # Once we open the file, we want the fstat() to ensure + # the file has not changed between the first stat() + # and the open(). + st = os.stat(fobj.fileno()) + return fobj, st, None async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]: loop = asyncio.get_running_loop() @@ -190,20 +212,44 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter # https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1 accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower() try: - file_path, st, file_encoding = await loop.run_in_executor( - None, self._get_file_path_stat_encoding, accept_encoding + fobj, st, file_encoding = await loop.run_in_executor( + None, self._open_file_path_stat_encoding, accept_encoding ) + except PermissionError: + self.set_status(HTTPForbidden.status_code) + return await super().prepare(request) except OSError: # Most likely to be FileNotFoundError or OSError for circular # symlinks in python >= 3.13, so respond with 404. self.set_status(HTTPNotFound.status_code) return await super().prepare(request) - # Forbid special files like sockets, pipes, devices, etc. - if not S_ISREG(st.st_mode): - self.set_status(HTTPForbidden.status_code) - return await super().prepare(request) + try: + # Forbid special files like sockets, pipes, devices, etc. + if not fobj or not S_ISREG(st.st_mode): + self.set_status(HTTPForbidden.status_code) + return await super().prepare(request) + return await self._prepare_open_file(request, fobj, st, file_encoding) + finally: + if fobj: + # We do not await here because we do not want to wait + # for the executor to finish before returning the response + # so the connection can begin servicing another request + # as soon as possible. + close_future = loop.run_in_executor(None, fobj.close) + # Hold a strong reference to the future to prevent it from being + # garbage collected before it completes. + _CLOSE_FUTURES.add(close_future) + close_future.add_done_callback(_CLOSE_FUTURES.remove) + + async def _prepare_open_file( + self, + request: "BaseRequest", + fobj: io.BufferedReader, + st: os.stat_result, + file_encoding: Optional[str], + ) -> Optional[AbstractStreamWriter]: etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}" last_modified = st.st_mtime @@ -346,18 +392,9 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter if count == 0 or must_be_empty_body(request.method, self.status): return await super().prepare(request) - try: - fobj = await loop.run_in_executor(None, file_path.open, "rb") - except PermissionError: - self.set_status(HTTPForbidden.status_code) - return await super().prepare(request) - if start: # be aware that start could be None or int=0 here. offset = start else: offset = 0 - try: - return await self._sendfile(request, fobj, offset, count) - finally: - await asyncio.shield(loop.run_in_executor(None, fobj.close)) + return await self._sendfile(request, fobj, offset, count) diff --git a/tests/test_web_urldispatcher.py b/tests/test_web_urldispatcher.py index 5cd4aebdc55..d21ecaa101a 100644 --- a/tests/test_web_urldispatcher.py +++ b/tests/test_web_urldispatcher.py @@ -579,16 +579,17 @@ async def test_access_mock_special_resource( my_special.touch() real_result = my_special.stat() - real_stat = pathlib.Path.stat + real_stat = os.stat - def mock_stat(self: pathlib.Path, **kwargs: Any) -> os.stat_result: - s = real_stat(self, **kwargs) + def mock_stat(path: Any, **kwargs: Any) -> os.stat_result: + s = real_stat(path, **kwargs) if os.path.samestat(s, real_result): mock_mode = S_IFIFO | S_IMODE(s.st_mode) s = os.stat_result([mock_mode] + list(s)[1:]) return s monkeypatch.setattr("pathlib.Path.stat", mock_stat) + monkeypatch.setattr("os.stat", mock_stat) app = web.Application() app.router.add_static("/", str(tmp_path))