Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore 304 performance after fixing FileResponse replace race #10113

Merged
merged 8 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/10113.bugfix.rst
163 changes: 91 additions & 72 deletions aiohttp/web_fileresponse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pathlib
import sys
from contextlib import suppress
from enum import Enum, auto
from mimetypes import MimeTypes
from stat import S_ISREG
from types import MappingProxyType
Expand Down Expand Up @@ -66,6 +67,16 @@
}
)


class _FileResponseResult(Enum):
"""The result of the file response."""

SEND_FILE = auto() # Ie a regular file to send
NOT_ACCEPTABLE = auto() # Ie a socket, or non-regular file
PRE_CONDITION_FAILED = auto() # Ie If-Match or If-None-Match failed
NOT_MODIFIED = auto() # 304 Not Modified


# Add custom pairs and clear the encodings map so guess_type ignores them.
CONTENT_TYPES.encodings_map.clear()
for content_type, extension in ADDITIONAL_CONTENT_TYPES.items():
Expand Down Expand Up @@ -163,17 +174,65 @@ async def _precondition_failed(
self.content_length = 0
return await super().prepare(request)

def _open_file_path_stat_encoding(
self, accept_encoding: str
) -> Tuple[Optional[io.BufferedReader], os.stat_result, Optional[str]]:
"""Return the io object, stat result, and encoding.
def _make_response(
self, request: "BaseRequest", accept_encoding: str
) -> Tuple[
_FileResponseResult, Optional[io.BufferedReader], os.stat_result, Optional[str]
]:
"""Return the response result, io object, stat result, and encoding.

If an uncompressed file is returned, the encoding is set to
:py:data:`None`.

This method should be called from a thread executor
since it calls os.stat which may block.
"""
file_path, st, file_encoding = self._get_file_path_stat_encoding(
accept_encoding
)
if not file_path:
return _FileResponseResult.NOT_ACCEPTABLE, None, st, None

etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"

# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2
if (ifmatch := request.if_match) is not None and not self._etag_match(
etag_value, ifmatch, weak=False
):
return _FileResponseResult.PRE_CONDITION_FAILED, None, st, file_encoding

if (
(unmodsince := request.if_unmodified_since) is not None
and ifmatch is None
and st.st_mtime > unmodsince.timestamp()
):
return _FileResponseResult.PRE_CONDITION_FAILED, None, st, file_encoding

# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2
if (ifnonematch := request.if_none_match) is not None and self._etag_match(
etag_value, ifnonematch, weak=True
):
return _FileResponseResult.NOT_MODIFIED, None, st, file_encoding

if (
(modsince := request.if_modified_since) is not None
and ifnonematch is None
and st.st_mtime <= modsince.timestamp()
):
return _FileResponseResult.NOT_MODIFIED, None, st, file_encoding

fobj = file_path.open("rb")
with suppress(OSError):
# fstat() may not be available on all platforms
# Once we open the file, we want the fstat() to ensure
# the file has not changed between the first stat()
# and the open().
st = os.stat(fobj.fileno())
return _FileResponseResult.SEND_FILE, fobj, st, file_encoding

def _get_file_path_stat_encoding(
self, accept_encoding: str
) -> Tuple[Optional[pathlib.Path], os.stat_result, Optional[str]]:
file_path = self._path
for file_extension, file_encoding in ENCODING_EXTENSIONS.items():
if file_encoding not in accept_encoding:
Expand All @@ -184,36 +243,22 @@ def _open_file_path_stat_encoding(
# Do not follow symlinks and ignore any non-regular files.
st = compressed_path.lstat()
if S_ISREG(st.st_mode):
fobj = compressed_path.open("rb")
with suppress(OSError):
# fstat() may not be available on all platforms
# Once we open the file, we want the fstat() to ensure
# the file has not changed between the first stat()
# and the open().
st = os.stat(fobj.fileno())
return fobj, st, file_encoding
return compressed_path, st, file_encoding

# Fallback to the uncompressed file
st = file_path.stat()
if not S_ISREG(st.st_mode):
return None, st, None
fobj = file_path.open("rb")
with suppress(OSError):
# fstat() may not be available on all platforms
# Once we open the file, we want the fstat() to ensure
# the file has not changed between the first stat()
# and the open().
st = os.stat(fobj.fileno())
return fobj, st, None
return file_path, st, None

async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
loop = asyncio.get_running_loop()
# Encoding comparisons should be case-insensitive
# https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1
accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
try:
fobj, st, file_encoding = await loop.run_in_executor(
None, self._open_file_path_stat_encoding, accept_encoding
response_result, fobj, st, file_encoding = await loop.run_in_executor(
None, self._make_response, request, accept_encoding
)
except PermissionError:
self.set_status(HTTPForbidden.status_code)
Expand All @@ -224,24 +269,32 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter
self.set_status(HTTPNotFound.status_code)
return await super().prepare(request)

try:
# Forbid special files like sockets, pipes, devices, etc.
if not fobj or not S_ISREG(st.st_mode):
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)
# Forbid special files like sockets, pipes, devices, etc.
if response_result is _FileResponseResult.NOT_ACCEPTABLE:
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)

if response_result is _FileResponseResult.PRE_CONDITION_FAILED:
return await self._precondition_failed(request)

if response_result is _FileResponseResult.NOT_MODIFIED:
etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
last_modified = st.st_mtime
return await self._not_modified(request, etag_value, last_modified)

assert fobj is not None
try:
return await self._prepare_open_file(request, fobj, st, file_encoding)
finally:
if fobj:
# We do not await here because we do not want to wait
# for the executor to finish before returning the response
# so the connection can begin servicing another request
# as soon as possible.
close_future = loop.run_in_executor(None, fobj.close)
# Hold a strong reference to the future to prevent it from being
# garbage collected before it completes.
_CLOSE_FUTURES.add(close_future)
close_future.add_done_callback(_CLOSE_FUTURES.remove)
# We do not await here because we do not want to wait
# for the executor to finish before returning the response
# so the connection can begin servicing another request
# as soon as possible.
close_future = loop.run_in_executor(None, fobj.close)
# Hold a strong reference to the future to prevent it from being
# garbage collected before it completes.
_CLOSE_FUTURES.add(close_future)
close_future.add_done_callback(_CLOSE_FUTURES.remove)

async def _prepare_open_file(
self,
Expand All @@ -250,43 +303,9 @@ async def _prepare_open_file(
st: os.stat_result,
file_encoding: Optional[str],
) -> Optional[AbstractStreamWriter]:
etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
last_modified = st.st_mtime

# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2
ifmatch = request.if_match
if ifmatch is not None and not self._etag_match(
etag_value, ifmatch, weak=False
):
return await self._precondition_failed(request)

unmodsince = request.if_unmodified_since
if (
unmodsince is not None
and ifmatch is None
and st.st_mtime > unmodsince.timestamp()
):
return await self._precondition_failed(request)

# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2
ifnonematch = request.if_none_match
if ifnonematch is not None and self._etag_match(
etag_value, ifnonematch, weak=True
):
return await self._not_modified(request, etag_value, last_modified)

modsince = request.if_modified_since
if (
modsince is not None
and ifnonematch is None
and st.st_mtime <= modsince.timestamp()
):
return await self._not_modified(request, etag_value, last_modified)

status = self._status
file_size = st.st_size
count = file_size

start = None

ifrange = request.if_range
Expand Down Expand Up @@ -375,7 +394,7 @@ async def _prepare_open_file(
# compress.
self._compression = False

self.etag = etag_value # type: ignore[assignment]
self.etag = f"{st.st_mtime_ns:x}-{st.st_size:x}" # type: ignore[assignment]
self.last_modified = st.st_mtime # type: ignore[assignment]
self.content_length = count

Expand Down
Loading