Skip to content

Commit

Permalink
Make the file argument to UploadFile required (#1413)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcelo Trylesinski <[email protected]>
  • Loading branch information
adriangb and Kludex authored Feb 4, 2023
1 parent ea2e794 commit 3697c8d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 33 deletions.
21 changes: 8 additions & 13 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import tempfile
import typing
from collections.abc import Sequence
from shlex import shlex
Expand Down Expand Up @@ -428,28 +427,24 @@ class UploadFile:
An uploaded file included as part of the request data.
"""

spool_max_size = 1024 * 1024
file: typing.BinaryIO
headers: "Headers"

def __init__(
self,
filename: str,
file: typing.Optional[typing.BinaryIO] = None,
content_type: str = "",
file: typing.BinaryIO,
*,
filename: typing.Optional[str] = None,
headers: "typing.Optional[Headers]" = None,
) -> None:
self.filename = filename
self.content_type = content_type
if file is None:
self.file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size) # type: ignore[assignment] # noqa: E501
else:
self.file = file
self.file = file
self.headers = headers or Headers()

@property
def content_type(self) -> typing.Optional[str]:
return self.headers.get("content-type", None)

@property
def _in_memory(self) -> bool:
# check for SpooledTemporaryFile._rolled
rolled_to_disk = getattr(self.file, "_rolled", True)
return not rolled_to_disk

Expand Down
12 changes: 6 additions & 6 deletions starlette/formparsers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing
from enum import Enum
from tempfile import SpooledTemporaryFile
from urllib.parse import unquote_plus

from starlette.datastructures import FormData, Headers, UploadFile
Expand Down Expand Up @@ -116,6 +117,8 @@ async def parse(self) -> FormData:


class MultiPartParser:
max_file_size = 1024 * 1024

def __init__(
self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
) -> None:
Expand Down Expand Up @@ -160,7 +163,7 @@ def on_end(self) -> None:

async def parse(self) -> FormData:
# Parse the Content-Type header to get the multipart boundary.
content_type, params = parse_options_header(self.headers["Content-Type"])
_, params = parse_options_header(self.headers["Content-Type"])
charset = params.get(b"charset", "utf-8")
if type(charset) == bytes:
charset = charset.decode("latin-1")
Expand All @@ -186,7 +189,6 @@ async def parse(self) -> FormData:
header_field = b""
header_value = b""
content_disposition = None
content_type = b""
field_name = ""
data = b""
file: typing.Optional[UploadFile] = None
Expand All @@ -202,7 +204,6 @@ async def parse(self) -> FormData:
for message_type, message_bytes in messages:
if message_type == MultiPartMessage.PART_BEGIN:
content_disposition = None
content_type = b""
data = b""
item_headers = []
elif message_type == MultiPartMessage.HEADER_FIELD:
Expand All @@ -213,8 +214,6 @@ async def parse(self) -> FormData:
field = header_field.lower()
if field == b"content-disposition":
content_disposition = header_value
elif field == b"content-type":
content_type = header_value
item_headers.append((field, header_value))
header_field = b""
header_value = b""
Expand All @@ -229,9 +228,10 @@ async def parse(self) -> FormData:
)
if b"filename" in options:
filename = _user_safe_decode(options[b"filename"], charset)
tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
file = UploadFile(
file=tempfile, # type: ignore[arg-type]
filename=filename,
content_type=content_type.decode("latin-1"),
headers=Headers(raw=item_headers),
)
else:
Expand Down
38 changes: 24 additions & 14 deletions tests/test_datastructures.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import io
from tempfile import SpooledTemporaryFile
from typing import BinaryIO

import pytest

Expand Down Expand Up @@ -269,20 +271,6 @@ def test_queryparams():
assert QueryParams(q) == q


class BigUploadFile(UploadFile):
spool_max_size = 1024


@pytest.mark.anyio
async def test_upload_file():
big_file = BigUploadFile("big-file")
await big_file.write(b"big-data" * 512)
await big_file.write(b"big-data")
await big_file.seek(0)
assert await big_file.read(1024) == b"big-data" * 128
await big_file.close()


@pytest.mark.anyio
async def test_upload_file_file_input():
"""Test passing file/stream into the UploadFile constructor"""
Expand All @@ -295,6 +283,28 @@ async def test_upload_file_file_input():
assert await file.read() == b"data and more data!"


@pytest.mark.anyio
@pytest.mark.parametrize("max_size", [1, 1024], ids=["rolled", "unrolled"])
async def test_uploadfile_rolling(max_size: int) -> None:
"""Test that we can r/w to a SpooledTemporaryFile
managed by UploadFile before and after it rolls to disk
"""
stream: BinaryIO = SpooledTemporaryFile( # type: ignore[assignment]
max_size=max_size
)
file = UploadFile(filename="file", file=stream)
assert await file.read() == b""
await file.write(b"data")
assert await file.read() == b""
await file.seek(0)
assert await file.read() == b"data"
await file.write(b" more")
assert await file.read() == b""
await file.seek(0)
assert await file.read() == b"data more"
await file.close()


def test_formdata():
stream = io.BytesIO(b"data")
upload = UploadFile(filename="file", file=stream)
Expand Down

0 comments on commit 3697c8d

Please sign in to comment.