diff --git a/multipart/multipart.py b/multipart/multipart.py index abe4dff..3adf9ac 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -15,7 +15,7 @@ from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError if TYPE_CHECKING: # pragma: no cover - from typing import Callable, TypedDict + from typing import Callable, Protocol, TypedDict class QuerystringCallbacks(TypedDict, total=False): on_field_start: Callable[[], None] @@ -55,6 +55,30 @@ class FileConfig(TypedDict, total=False): UPLOAD_KEEP_EXTENSIONS: bool MAX_MEMORY_FILE_SIZE: int + class _FormProtocol(Protocol): + def write(self, data: bytes) -> int: + ... + + def finalize(self) -> None: + ... + + def close(self) -> None: + ... + + class FieldProtocol(_FormProtocol, Protocol): + def __init__(self, name: bytes) -> None: + ... + + def set_none(self) -> None: + ... + + class FileProtocol(_FormProtocol, Protocol): + def __init__(self, file_name: bytes | None, field_name: bytes | None, config: FileConfig) -> None: + ... + + OnFieldCallback = Callable[[FieldProtocol], None] + OnFileCallback = Callable[[FieldProtocol], None] + # Unique missing object. _missing = object() @@ -190,7 +214,7 @@ class Field: :param name: the name of the form field """ - def __init__(self, name: str): + def __init__(self, name: bytes): self._name = name self._value: list[bytes] = [] @@ -198,7 +222,7 @@ def __init__(self, name: str): self._cache = _missing @classmethod - def from_value(cls, name: str, value: bytes | None) -> Field: + def from_value(cls, name: bytes, value: bytes | None) -> Field: """Create an instance of a :class:`Field`, and set the corresponding value - either None or an actual value. This method will also finalize the Field itself. @@ -260,7 +284,7 @@ def set_none(self) -> None: self._cache = None @property - def field_name(self) -> str: + def field_name(self) -> bytes: """This property returns the name of the field.""" return self._name @@ -1562,6 +1586,7 @@ class FormParser: field_instance.write(data) field_instance.finalize() field_instance.close() + field_instance.set_none() :param config: Configuration to use for this FormParser. The default values are taken from the DEFAULT_CONFIG value, and then @@ -1584,14 +1609,14 @@ class FormParser: def __init__( self, - content_type, - on_field, - on_file, - on_end=None, - boundary=None, - file_name=None, - FileClass=File, - FieldClass=Field, + content_type: str, + on_field: OnFieldCallback, + on_file: OnFileCallback, + on_end: Callable[[], None] | None = None, + boundary: bytes | str | None = None, + file_name: bytes | None = None, + FileClass: type[FileProtocol] = File, + FieldClass: type[FieldProtocol] = Field, config: FormParserConfig = {}, ): self.logger = logging.getLogger(__name__) @@ -1617,22 +1642,22 @@ def __init__( # Depending on the Content-Type, we instantiate the correct parser. if content_type == "application/octet-stream": - # Work around the lack of 'nonlocal' in Py2 - class vars: - f = None + f: FileProtocol | None = None def on_start() -> None: - vars.f = FileClass(file_name, None, config=self.config) + nonlocal f + f = FileClass(file_name, None, config=self.config) def on_data(data: bytes, start: int, end: int) -> None: - vars.f.write(data[start:end]) + nonlocal f + f.write(data[start:end]) - def on_end() -> None: + def _on_end() -> None: # Finalize the file itself. - vars.f.finalize() + f.finalize() # Call our callback. - on_file(vars.f) + on_file(f) # Call the on-end callback. if self.on_end is not None: @@ -1640,15 +1665,14 @@ def on_end() -> None: # Instantiate an octet-stream parser parser = OctetStreamParser( - callbacks={"on_start": on_start, "on_data": on_data, "on_end": on_end}, + callbacks={"on_start": on_start, "on_data": on_data, "on_end": _on_end}, max_size=self.config["MAX_BODY_SIZE"], ) elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded": name_buffer: list[bytes] = [] - class vars: - f = None + f: FieldProtocol | None = None def on_field_start() -> None: pass @@ -1657,25 +1681,27 @@ def on_field_name(data: bytes, start: int, end: int) -> None: name_buffer.append(data[start:end]) def on_field_data(data: bytes, start: int, end: int) -> None: - if vars.f is None: - vars.f = FieldClass(b"".join(name_buffer)) + nonlocal f + if f is None: + f = FieldClass(b"".join(name_buffer)) del name_buffer[:] - vars.f.write(data[start:end]) + f.write(data[start:end]) def on_field_end() -> None: + nonlocal f # Finalize and call callback. - if vars.f is None: + if f is None: # If we get here, it's because there was no field data. # We create a field, set it to None, and then continue. - vars.f = FieldClass(b"".join(name_buffer)) + f = FieldClass(b"".join(name_buffer)) del name_buffer[:] - vars.f.set_none() + f.set_none() - vars.f.finalize() - on_field(vars.f) - vars.f = None + f.finalize() + on_field(f) + f = None - def on_end() -> None: + def _on_end() -> None: if self.on_end is not None: self.on_end() @@ -1686,7 +1712,7 @@ def on_end() -> None: "on_field_name": on_field_name, "on_field_data": on_field_data, "on_field_end": on_field_end, - "on_end": on_end, + "on_end": _on_end, }, max_size=self.config["MAX_BODY_SIZE"], ) @@ -1700,26 +1726,26 @@ def on_end() -> None: header_value: list[bytes] = [] headers = {} - # No 'nonlocal' on Python 2 :-( - class vars: - f = None - writer = None - is_file = False + f: FileProtocol | FieldProtocol | None = None + writer = None + is_file = False def on_part_begin(): pass - def on_part_data(data: bytes, start: int, end: int): - bytes_processed = vars.writer.write(data[start:end]) + def on_part_data(data: bytes, start: int, end: int) -> None: + nonlocal writer + bytes_processed = writer.write(data[start:end]) # TODO: check for error here. return bytes_processed def on_part_end() -> None: - vars.f.finalize() - if vars.is_file: - on_file(vars.f) + nonlocal f, is_file + f.finalize() + if is_file: + on_file(f) else: - on_field(vars.f) + on_field(f) def on_header_field(data: bytes, start: int, end: int): header_name.append(data[start:end]) @@ -1733,8 +1759,9 @@ def on_header_end(): del header_value[:] def on_headers_finished() -> None: + nonlocal is_file, f, writer # Reset the 'is file' flag. - vars.is_file = False + is_file = False # Parse the content-disposition header. # TODO: handle mixed case @@ -1748,10 +1775,10 @@ def on_headers_finished() -> None: # Create the proper class. if file_name is None: - vars.f = FieldClass(field_name) + f = FieldClass(field_name) else: - vars.f = FileClass(file_name, field_name, config=self.config) - vars.is_file = True + f = FileClass(file_name, field_name, config=self.config) + is_file = True # Parse the given Content-Transfer-Encoding to determine what # we need to do with the incoming data. @@ -1759,13 +1786,13 @@ def on_headers_finished() -> None: transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit") if transfer_encoding == b"binary" or transfer_encoding == b"8bit" or transfer_encoding == b"7bit": - vars.writer = vars.f + writer = f elif transfer_encoding == b"base64": - vars.writer = Base64Decoder(vars.f) + writer = Base64Decoder(f) elif transfer_encoding == b"quoted-printable": - vars.writer = QuotedPrintableDecoder(vars.f) + writer = QuotedPrintableDecoder(f) else: self.logger.warning("Unknown Content-Transfer-Encoding: %r", transfer_encoding) @@ -1774,10 +1801,11 @@ def on_headers_finished() -> None: else: # If we aren't erroring, then we just treat this as an # unencoded Content-Transfer-Encoding. - vars.writer = vars.f + writer = f - def on_end() -> None: - vars.writer.finalize() + def _on_end() -> None: + nonlocal writer + writer.finalize() if self.on_end is not None: self.on_end() @@ -1792,7 +1820,7 @@ def on_end() -> None: "on_header_value": on_header_value, "on_header_end": on_header_end, "on_headers_finished": on_headers_finished, - "on_end": on_end, + "on_end": _on_end, }, max_size=self.config["MAX_BODY_SIZE"], )