Skip to content

Commit

Permalink
Fix race in FileResponse when the file is changed between the stat an…
Browse files Browse the repository at this point in the history
…d open calls

There was a race in ``FileResponse`` where the stat would be incorrect
if the file was changed out between the `stat` and `open` syscalls.
This would lead to various unexpected behaviors such as trying to read
beyond the length of the file or sending a partial file. This problem
is likely to occour when files are being renamed/linked into place.

An example of how this can happen with a system that provides weather
data every 60s:

An external process writes `.weather.txt` at the top of
each minute, and than renames it to `weather.txt`. In this
case `aiohttp` may stat the old `weather.txt`, and than
open the new `weather.txt`, and use the `stat` result from
the original file.

To fix this we now `fstat` the open file on operating systems
where `fstat` is available

fixes #8013
  • Loading branch information
bdraco committed Dec 4, 2024
1 parent fcce1bf commit f053352
Showing 1 changed file with 184 additions and 156 deletions.
340 changes: 184 additions & 156 deletions aiohttp/web_fileresponse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import io
import os
import pathlib
from contextlib import suppress
Expand All @@ -13,6 +14,7 @@
Callable,
Final,
Optional,
Set,
Tuple,
cast,
)
Expand Down Expand Up @@ -69,6 +71,24 @@
CONTENT_TYPES.add_type(content_type, extension)


_CLOSE_FUTURES: Set[asyncio.Future[None]] = set()


def _stat_open_file(
file_path: pathlib.Path, fobj: io.BufferedReader, st: Optional[os.stat_result]
) -> os.stat_result:
"""Return the stat result of the file or the fallback stat result.
Ideally we can use fstat() to get the stat result of the file object,
to ensure we are returning the correct length of the open file,
but it is not possible to get the file descriptor from the file object
on some operating systems, so we have to use the stat result of the file path.
"""
with suppress(OSError): # May not work on Windows
st = os.stat(fobj.fileno())
return st or file_path.stat()


class FileResponse(StreamResponse):
"""A response object can be used to send files."""

Expand Down Expand Up @@ -157,10 +177,10 @@ async def _precondition_failed(
self.content_length = 0
return await super().prepare(request)

def _get_file_path_stat_encoding(
def _open_file_path_stat_encoding(
self, accept_encoding: str
) -> Tuple[pathlib.Path, os.stat_result, Optional[str]]:
"""Return the file path, stat result, and encoding.
) -> Tuple[io.BufferedReader, os.stat_result, Optional[str]]:
"""Return the io object, stat result, and encoding.
If an uncompressed file is returned, the encoding is set to
:py:data:`None`.
Expand All @@ -178,183 +198,191 @@ def _get_file_path_stat_encoding(
# Do not follow symlinks and ignore any non-regular files.
st = compressed_path.lstat()
if S_ISREG(st.st_mode):
return compressed_path, st, file_encoding
fobj = compressed_path.open("rb")
st = _stat_open_file(compressed_path, fobj, st)
return fobj, st, file_encoding

# Fallback to the uncompressed file
return file_path, file_path.stat(), None
fobj = file_path.open("rb")
st = _stat_open_file(file_path, fobj, None)
return fobj, st, None

async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
loop = asyncio.get_running_loop()
# Encoding comparisons should be case-insensitive
# https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1
accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
try:
file_path, st, file_encoding = await loop.run_in_executor(
None, self._get_file_path_stat_encoding, accept_encoding
fobj, st, file_encoding = await loop.run_in_executor(
None, self._open_file_path_stat_encoding, accept_encoding
)
except PermissionError:
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)
except OSError:
# Most likely to be FileNotFoundError or OSError for circular
# symlinks in python >= 3.13, so respond with 404.
self.set_status(HTTPNotFound.status_code)
return await super().prepare(request)

# Forbid special files like sockets, pipes, devices, etc.
if not S_ISREG(st.st_mode):
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)

etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
last_modified = st.st_mtime

# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2
ifmatch = request.if_match
if ifmatch is not None and not self._etag_match(
etag_value, ifmatch, weak=False
):
return await self._precondition_failed(request)

unmodsince = request.if_unmodified_since
if (
unmodsince is not None
and ifmatch is None
and st.st_mtime > unmodsince.timestamp()
):
return await self._precondition_failed(request)

# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2
ifnonematch = request.if_none_match
if ifnonematch is not None and self._etag_match(
etag_value, ifnonematch, weak=True
):
return await self._not_modified(request, etag_value, last_modified)

modsince = request.if_modified_since
if (
modsince is not None
and ifnonematch is None
and st.st_mtime <= modsince.timestamp()
):
return await self._not_modified(request, etag_value, last_modified)

status = self._status
file_size = st.st_size
count = file_size

start = None

ifrange = request.if_range
if ifrange is None or st.st_mtime <= ifrange.timestamp():
# If-Range header check:
# condition = cached date >= last modification date
# return 206 if True else 200.
# if False:
# Range header would not be processed, return 200
# if True but Range header missing
# return 200
try:
rng = request.http_range
start = rng.start
end = rng.stop
except ValueError:
# https://tools.ietf.org/html/rfc7233:
# A server generating a 416 (Range Not Satisfiable) response to
# a byte-range request SHOULD send a Content-Range header field
# with an unsatisfied-range value.
# The complete-length in a 416 response indicates the current
# length of the selected representation.
#
# Will do the same below. Many servers ignore this and do not
# send a Content-Range header with HTTP 416
self.headers[hdrs.CONTENT_RANGE] = f"bytes */{file_size}"
self.set_status(HTTPRequestRangeNotSatisfiable.status_code)
try:
# Forbid special files like sockets, pipes, devices, etc.
if not S_ISREG(st.st_mode):
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)

# If a range request has been made, convert start, end slice
# notation into file pointer offset and count
if start is not None or end is not None:
if start < 0 and end is None: # return tail of file
start += file_size
if start < 0:
# if Range:bytes=-1000 in request header but file size
# is only 200, there would be trouble without this
start = 0
count = file_size - start
else:
# rfc7233:If the last-byte-pos value is
# absent, or if the value is greater than or equal to
# the current length of the representation data,
# the byte range is interpreted as the remainder
# of the representation (i.e., the server replaces the
# value of last-byte-pos with a value that is one less than
# the current length of the selected representation).
count = (
min(end if end is not None else file_size, file_size) - start
)

if start >= file_size:
# HTTP 416 should be returned in this case.
etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
last_modified = st.st_mtime

# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2
ifmatch = request.if_match
if ifmatch is not None and not self._etag_match(
etag_value, ifmatch, weak=False
):
return await self._precondition_failed(request)

unmodsince = request.if_unmodified_since
if (
unmodsince is not None
and ifmatch is None
and st.st_mtime > unmodsince.timestamp()
):
return await self._precondition_failed(request)

# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2
ifnonematch = request.if_none_match
if ifnonematch is not None and self._etag_match(
etag_value, ifnonematch, weak=True
):
return await self._not_modified(request, etag_value, last_modified)

modsince = request.if_modified_since
if (
modsince is not None
and ifnonematch is None
and st.st_mtime <= modsince.timestamp()
):
return await self._not_modified(request, etag_value, last_modified)

status = self._status
file_size = st.st_size
count = file_size

start = None

ifrange = request.if_range
if ifrange is None or st.st_mtime <= ifrange.timestamp():
# If-Range header check:
# condition = cached date >= last modification date
# return 206 if True else 200.
# if False:
# Range header would not be processed, return 200
# if True but Range header missing
# return 200
try:
rng = request.http_range
start = rng.start
end = rng.stop
except ValueError:
# https://tools.ietf.org/html/rfc7233:
# A server generating a 416 (Range Not Satisfiable) response to
# a byte-range request SHOULD send a Content-Range header field
# with an unsatisfied-range value.
# The complete-length in a 416 response indicates the current
# length of the selected representation.
#
# According to https://tools.ietf.org/html/rfc7233:
# If a valid byte-range-set includes at least one
# byte-range-spec with a first-byte-pos that is less than
# the current length of the representation, or at least one
# suffix-byte-range-spec with a non-zero suffix-length,
# then the byte-range-set is satisfiable. Otherwise, the
# byte-range-set is unsatisfiable.
# Will do the same below. Many servers ignore this and do not
# send a Content-Range header with HTTP 416
self.headers[hdrs.CONTENT_RANGE] = f"bytes */{file_size}"
self.set_status(HTTPRequestRangeNotSatisfiable.status_code)
return await super().prepare(request)

status = HTTPPartialContent.status_code
# Even though you are sending the whole file, you should still
# return a HTTP 206 for a Range request.
self.set_status(status)

# If the Content-Type header is not already set, guess it based on the
# extension of the request path. The encoding returned by guess_type
# can be ignored since the map was cleared above.
if hdrs.CONTENT_TYPE not in self.headers:
self.content_type = (
CONTENT_TYPES.guess_type(self._path)[0] or FALLBACK_CONTENT_TYPE
)

if file_encoding:
self.headers[hdrs.CONTENT_ENCODING] = file_encoding
self.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
# Disable compression if we are already sending
# a compressed file since we don't want to double
# compress.
self._compression = False

self.etag = etag_value # type: ignore[assignment]
self.last_modified = st.st_mtime # type: ignore[assignment]
self.content_length = count

self.headers[hdrs.ACCEPT_RANGES] = "bytes"

real_start = cast(int, start)

if status == HTTPPartialContent.status_code:
self.headers[hdrs.CONTENT_RANGE] = "bytes {}-{}/{}".format(
real_start, real_start + count - 1, file_size
)

# If we are sending 0 bytes calling sendfile() will throw a ValueError
if count == 0 or must_be_empty_body(request.method, self.status):
return await super().prepare(request)

try:
fobj = await loop.run_in_executor(None, file_path.open, "rb")
except PermissionError:
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)
# If a range request has been made, convert start, end slice
# notation into file pointer offset and count
if start is not None or end is not None:
if start < 0 and end is None: # return tail of file
start += file_size
if start < 0:
# if Range:bytes=-1000 in request header but file size
# is only 200, there would be trouble without this
start = 0
count = file_size - start
else:
# rfc7233:If the last-byte-pos value is
# absent, or if the value is greater than or equal to
# the current length of the representation data,
# the byte range is interpreted as the remainder
# of the representation (i.e., the server replaces the
# value of last-byte-pos with a value that is one less than
# the current length of the selected representation).
count = (
min(end if end is not None else file_size, file_size)
- start
)

if start >= file_size:
# HTTP 416 should be returned in this case.
#
# According to https://tools.ietf.org/html/rfc7233:
# If a valid byte-range-set includes at least one
# byte-range-spec with a first-byte-pos that is less than
# the current length of the representation, or at least one
# suffix-byte-range-spec with a non-zero suffix-length,
# then the byte-range-set is satisfiable. Otherwise, the
# byte-range-set is unsatisfiable.
self.headers[hdrs.CONTENT_RANGE] = f"bytes */{file_size}"
self.set_status(HTTPRequestRangeNotSatisfiable.status_code)
return await super().prepare(request)

status = HTTPPartialContent.status_code
# Even though you are sending the whole file, you should still
# return a HTTP 206 for a Range request.
self.set_status(status)

# If the Content-Type header is not already set, guess it based on the
# extension of the request path. The encoding returned by guess_type
# can be ignored since the map was cleared above.
if hdrs.CONTENT_TYPE not in self.headers:
self.content_type = (
CONTENT_TYPES.guess_type(self._path)[0] or FALLBACK_CONTENT_TYPE
)

if file_encoding:
self.headers[hdrs.CONTENT_ENCODING] = file_encoding
self.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
# Disable compression if we are already sending
# a compressed file since we don't want to double
# compress.
self._compression = False

self.etag = etag_value # type: ignore[assignment]
self.last_modified = st.st_mtime # type: ignore[assignment]
self.content_length = count

self.headers[hdrs.ACCEPT_RANGES] = "bytes"

real_start = cast(int, start)

if status == HTTPPartialContent.status_code:
self.headers[hdrs.CONTENT_RANGE] = "bytes {}-{}/{}".format(
real_start, real_start + count - 1, file_size
)

# If we are sending 0 bytes calling sendfile() will throw a ValueError
if count == 0 or must_be_empty_body(request.method, self.status):
return await super().prepare(request)

if start: # be aware that start could be None or int=0 here.
offset = start
else:
offset = 0
if start: # be aware that start could be None or int=0 here.
offset = start
else:
offset = 0

try:
return await self._sendfile(request, fobj, offset, count)
finally:
await asyncio.shield(loop.run_in_executor(None, fobj.close))
# We do not await here because we do not want to wait
# for the executor to finish before returning the response.
close_future = loop.run_in_executor(None, fobj.close)
# Hold a strong reference to the future to prevent it from being
# garbage collected before it completes.
_CLOSE_FUTURES.add(close_future)
close_future.add_done_callback(_CLOSE_FUTURES.remove)

0 comments on commit f053352

Please sign in to comment.