diff --git a/.github/workflows/test.yaml b/.github/workflows/main.yaml similarity index 85% rename from .github/workflows/test.yaml rename to .github/workflows/main.yaml index 7448d91..7b4fd5c 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/main.yaml @@ -26,6 +26,10 @@ jobs: run: | python -m pip install --upgrade pip pip install .[dev] + - name: Lint + if: matrix.python-version == '3.8' + run: | + ruff multipart tests - name: Test with pytest run: | inv test diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 221b67f..cc38611 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -18,12 +18,10 @@ permissions: jobs: deploy: - runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Set up Python 3.10 uses: actions/setup-python@v5 with: diff --git a/multipart/__init__.py b/multipart/__init__.py index 28c7ad6..3c8a2e8 100644 --- a/multipart/__init__.py +++ b/multipart/__init__.py @@ -4,12 +4,13 @@ __copyright__ = "Copyright (c) 2012-2013, Andrew Dunham" __version__ = "0.0.8" +from .multipart import FormParser, MultipartParser, OctetStreamParser, QuerystringParser, create_form_parser, parse_form -from .multipart import ( - FormParser, - MultipartParser, - OctetStreamParser, - QuerystringParser, - create_form_parser, - parse_form, +__all__ = ( + "FormParser", + "MultipartParser", + "OctetStreamParser", + "QuerystringParser", + "create_form_parser", + "parse_form", ) diff --git a/multipart/decoders.py b/multipart/decoders.py index 0d7ab32..417650c 100644 --- a/multipart/decoders.py +++ b/multipart/decoders.py @@ -59,8 +59,7 @@ def write(self, data): try: decoded = base64.b64decode(val) except binascii.Error: - raise DecodeError('There was an error raised while decoding ' - 'base64-encoded data.') + raise DecodeError("There was an error raised while decoding base64-encoded data.") self.underlying.write(decoded) @@ -69,7 +68,7 @@ def write(self, data): if remaining_len > 0: self.cache = data[-remaining_len:] else: - self.cache = b'' + self.cache = b"" # Return the length of the data to indicate no error. return len(data) @@ -78,7 +77,7 @@ def close(self): """Close this decoder. If the underlying object has a `close()` method, this function will call it. """ - if hasattr(self.underlying, 'close'): + if hasattr(self.underlying, "close"): self.underlying.close() def finalize(self): @@ -91,11 +90,11 @@ def finalize(self): call it. """ if len(self.cache) > 0: - raise DecodeError('There are %d bytes remaining in the ' - 'Base64Decoder cache when finalize() is called' - % len(self.cache)) + raise DecodeError( + "There are %d bytes remaining in the Base64Decoder cache when finalize() is called" % len(self.cache) + ) - if hasattr(self.underlying, 'finalize'): + if hasattr(self.underlying, "finalize"): self.underlying.finalize() def __repr__(self): @@ -111,8 +110,9 @@ class QuotedPrintableDecoder: :param underlying: the underlying object to pass writes to """ + def __init__(self, underlying): - self.cache = b'' + self.cache = b"" self.underlying = underlying def write(self, data): @@ -128,11 +128,11 @@ def write(self, data): # If the last 2 characters have an '=' sign in it, then we won't be # able to decode the encoded value and we'll need to save it for the # next decoding step. - if data[-2:].find(b'=') != -1: + if data[-2:].find(b"=") != -1: enc, rest = data[:-2], data[-2:] else: enc = data - rest = b'' + rest = b"" # Encode and write, if we have data. 
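# --- A sketch (not part of this diff) of the chunk caching these decoders
# keep through the reformat above: Base64Decoder.write holds back any bytes
# that are not a multiple of 4 until the next write completes the quantum.
from io import BytesIO

from multipart.decoders import Base64Decoder

buf = BytesIO()
dec = Base64Decoder(buf)
dec.write(b"Zm9vYmF")  # 7 bytes: "Zm9v" decodes now, "YmF" is cached
dec.write(b"y")        # completes the cached quantum -> "bar"
dec.finalize()         # would raise DecodeError if bytes were still cached
assert buf.getvalue() == b"foobar"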
if len(enc) > 0: @@ -146,7 +146,7 @@ def close(self): """Close this decoder. If the underlying object has a `close()` method, this function will call it. """ - if hasattr(self.underlying, 'close'): + if hasattr(self.underlying, "close"): self.underlying.close() def finalize(self): @@ -161,10 +161,10 @@ def finalize(self): # If we have a cache, write and then remove it. if len(self.cache) > 0: self.underlying.write(binascii.a2b_qp(self.cache)) - self.cache = b'' + self.cache = b"" # Finalize our underlying stream. - if hasattr(self.underlying, 'finalize'): + if hasattr(self.underlying, "finalize"): self.underlying.finalize() def __repr__(self): diff --git a/multipart/exceptions.py b/multipart/exceptions.py index 016e7f7..cc3671f 100644 --- a/multipart/exceptions.py +++ b/multipart/exceptions.py @@ -1,6 +1,5 @@ class FormParserError(ValueError): """Base error class for our form parser.""" - pass class ParseError(FormParserError): @@ -17,30 +16,19 @@ class MultipartParseError(ParseError): """This is a specific error that is raised when the MultipartParser detects an error while parsing. """ - pass class QuerystringParseError(ParseError): """This is a specific error that is raised when the QuerystringParser detects an error while parsing. """ - pass class DecodeError(ParseError): """This exception is raised when there is a decoding error - for example with the Base64Decoder or QuotedPrintableDecoder. """ - pass - - -# On Python 3.3, IOError is the same as OSError, so we don't want to inherit -# from both of them. We handle this case below. -if IOError is not OSError: # pragma: no cover - class FileError(FormParserError, IOError, OSError): - """Exception class for problems with the File class.""" - pass -else: # pragma: no cover - class FileError(FormParserError, OSError): - """Exception class for problems with the File class.""" - pass + + +class FileError(FormParserError, OSError): + """Exception class for problems with the File class.""" diff --git a/multipart/multipart.py b/multipart/multipart.py index 73910da..a427f14 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -1,70 +1,82 @@ -from .decoders import * -from .exceptions import * +from __future__ import annotations +import logging import os -import sys import shutil -import logging +import sys import tempfile +from email.message import Message +from enum import IntEnum from io import BytesIO from numbers import Number -from email.message import Message -from typing import Dict, Union, Tuple +from typing import Dict, Tuple, Union + +from .decoders import Base64Decoder, QuotedPrintableDecoder +from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError # Unique missing object. _missing = object() # States for the querystring parser. STATE_BEFORE_FIELD = 0 -STATE_FIELD_NAME = 1 -STATE_FIELD_DATA = 2 - -# States for the multipart parser -STATE_START = 0 -STATE_START_BOUNDARY = 1 -STATE_HEADER_FIELD_START = 2 -STATE_HEADER_FIELD = 3 -STATE_HEADER_VALUE_START = 4 -STATE_HEADER_VALUE = 5 -STATE_HEADER_VALUE_ALMOST_DONE = 6 -STATE_HEADERS_ALMOST_DONE = 7 -STATE_PART_DATA_START = 8 -STATE_PART_DATA = 9 -STATE_PART_DATA_END = 10 -STATE_END = 11 - -STATES = [ - "START", - "START_BOUNDARY", "HEADER_FIELD_START", "HEADER_FIELD", "HEADER_VALUE_START", "HEADER_VALUE", - "HEADER_VALUE_ALMOST_DONE", "HEADRES_ALMOST_DONE", "PART_DATA_START", "PART_DATA", "PART_DATA_END", "END" -] +STATE_FIELD_NAME = 1 +STATE_FIELD_DATA = 2 + + +class MultipartState(IntEnum): + """Multipart parser states. 
+ + These are used to keep track of the state of the parser, and are used to determine + what to do when new data is encountered. + """ + + START = 0 + START_BOUNDARY = 1 + HEADER_FIELD_START = 2 + HEADER_FIELD = 3 + HEADER_VALUE_START = 4 + HEADER_VALUE = 5 + HEADER_VALUE_ALMOST_DONE = 6 + HEADERS_ALMOST_DONE = 7 + PART_DATA_START = 8 + PART_DATA = 9 + PART_DATA_END = 10 + END = 11 # Flags for the multipart parser. -FLAG_PART_BOUNDARY = 1 -FLAG_LAST_BOUNDARY = 2 +FLAG_PART_BOUNDARY = 1 +FLAG_LAST_BOUNDARY = 2 # Get constants. Since iterating over a str on Python 2 gives you a 1-length # string, but iterating over a bytes object on Python 3 gives you an integer, # we need to save these constants. -CR = b'\r'[0] -LF = b'\n'[0] -COLON = b':'[0] -SPACE = b' '[0] -HYPHEN = b'-'[0] -AMPERSAND = b'&'[0] -SEMICOLON = b';'[0] -LOWER_A = b'a'[0] -LOWER_Z = b'z'[0] -NULL = b'\x00'[0] +CR = b"\r"[0] +LF = b"\n"[0] +COLON = b":"[0] +SPACE = b" "[0] +HYPHEN = b"-"[0] +AMPERSAND = b"&"[0] +SEMICOLON = b";"[0] +LOWER_A = b"a"[0] +LOWER_Z = b"z"[0] +NULL = b"\x00"[0] + # Lower-casing a character is different, because of the difference between # str on Py2, and bytes on Py3. Same with getting the ordinal value of a byte, # and joining a list of bytes together. # These functions abstract that. -lower_char = lambda c: c | 0x20 -ord_char = lambda c: c -join_bytes = lambda b: bytes(list(b)) +def lower_char(c): + return c | 0x20 + + +def ord_char(c): + return c + + +def join_bytes(b): + return bytes(list(b)) def parse_options_header(value: Union[str, bytes]) -> Tuple[bytes, Dict[bytes, bytes]]: @@ -75,27 +87,27 @@ def parse_options_header(value: Union[str, bytes]) -> Tuple[bytes, Dict[bytes, b # Uses email.message.Message to parse the header as described in PEP 594. # Ref: https://peps.python.org/pep-0594/#cgi if not value: - return (b'', {}) + return (b"", {}) # If we are passed bytes, we assume that it conforms to WSGI, encoding in latin-1. if isinstance(value, bytes): # pragma: no cover - value = value.decode('latin-1') + value = value.decode("latin-1") # For types - assert isinstance(value, str), 'Value should be a string by now' + assert isinstance(value, str), "Value should be a string by now" # If we have no options, return the string as-is. if ";" not in value: - return (value.lower().strip().encode('latin-1'), {}) + return (value.lower().strip().encode("latin-1"), {}) # Split at the first semicolon, to get our value and then options. # ctype, rest = value.split(b';', 1) message = Message() - message['content-type'] = value + message["content-type"] = value params = message.get_params() # If there were no parameters, this would have already returned above - assert params, 'At least the content type value should be present' - ctype = params.pop(0)[0].encode('latin-1') + assert params, "At least the content type value should be present" + ctype = params.pop(0)[0].encode("latin-1") options = {} for param in params: key, value = param @@ -106,10 +118,10 @@ def parse_options_header(value: Union[str, bytes]) -> Tuple[bytes, Dict[bytes, b value = value[-1] # If the value is a filename, we need to fix a bug on IE6 that sends # the full file path instead of the filename. 
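# --- A sketch (not from this diff) of parse_options_header in action,
# mirroring the test suite: the header is fed through email.message.Message,
# the main value comes back lower-cased and latin-1 encoded, and the IE6
# filename fix just below strips a full Windows path down to its basename.
from multipart.multipart import parse_options_header

ctype, opts = parse_options_header(b'form-data; name="upload"; filename="C:\\dir\\file.txt"')
assert ctype == b"form-data"
assert opts[b"name"] == b"upload"
assert opts[b"filename"] == b"file.txt"  # path trimmed by the IE6 workaround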
- if key == 'filename': - if value[1:3] == ':\\' or value[:2] == '\\\\': - value = value.split('\\')[-1] - options[key.encode('latin-1')] = value.encode('latin-1') + if key == "filename": + if value[1:3] == ":\\" or value[:2] == "\\\\": + value = value.split("\\")[-1] + options[key.encode("latin-1")] = value.encode("latin-1") return ctype, options @@ -128,6 +140,7 @@ class Field: :param name: the name of the form field """ + def __init__(self, name): self._name = name self._value = [] @@ -172,22 +185,19 @@ def on_data(self, data): return len(data) def on_end(self): - """This method is called whenever the Field is finalized. - """ + """This method is called whenever the Field is finalized.""" if self._cache is _missing: - self._cache = b''.join(self._value) + self._cache = b"".join(self._value) def finalize(self): - """Finalize the form field. - """ + """Finalize the form field.""" self.on_end() def close(self): - """Close the Field object. This will free any underlying cache. - """ + """Close the Field object. This will free any underlying cache.""" # Free our value array. if self._cache is _missing: - self._cache = b''.join(self._value) + self._cache = b"".join(self._value) del self._value @@ -209,16 +219,13 @@ def field_name(self): def value(self): """This property returns the value of the form field.""" if self._cache is _missing: - self._cache = b''.join(self._value) + self._cache = b"".join(self._value) return self._cache def __eq__(self, other): if isinstance(other, Field): - return ( - self.field_name == other.field_name and - self.value == other.value - ) + return self.field_name == other.field_name and self.value == other.value else: return NotImplemented @@ -230,11 +237,7 @@ def __repr__(self): else: v = repr(self.value) - return "{}(field_name={!r}, value={})".format( - self.__class__.__name__, - self.field_name, - v - ) + return "{}(field_name={!r}, value={})".format(self.__class__.__name__, self.field_name, v) class File: @@ -296,6 +299,7 @@ class File: :param config: The configuration for this File. See above for valid configuration keys and their corresponding values. """ + def __init__(self, file_name, field_name=None, config={}): # Save configuration, set other variables default. self.logger = logging.getLogger(__name__) @@ -327,8 +331,7 @@ def field_name(self): @property def file_name(self): - """The file name given in the upload request. - """ + """The file name given in the upload request.""" return self._file_name @property @@ -369,9 +372,7 @@ def flush_to_disk(self): warning will be logged to this module's logger. """ if not self._in_memory: - self.logger.warning( - "Trying to flush to disk when we're not in memory" - ) + self.logger.warning("Trying to flush to disk when we're not in memory") return # Go back to the start of our file. @@ -397,14 +398,13 @@ def flush_to_disk(self): old_fileobj.close() def _get_disk_file(self): - """This function is responsible for getting a file object on-disk for us. 
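# --- A sketch (not part of this diff) of File's in-memory-to-disk spill,
# mirroring test_file_fallback: once more than MAX_MEMORY_FILE_SIZE bytes
# arrive, the BytesIO contents are flushed into a temporary file and later
# writes go straight to disk.
from multipart.multipart import File

f = File(b"foo.txt", config={"MAX_MEMORY_FILE_SIZE": 4})
f.write(b"1234")
assert f.in_memory            # at the limit, still buffered in memory
f.write(b"5678")
assert not f.in_memory        # over the limit, flushed to a tempfile
f.file_object.seek(0)
assert f.file_object.read() == b"12345678"
f.close()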
- """ + """This function is responsible for getting a file object on-disk for us.""" self.logger.info("Opening a file on disk") - file_dir = self._config.get('UPLOAD_DIR') - keep_filename = self._config.get('UPLOAD_KEEP_FILENAME', False) - keep_extensions = self._config.get('UPLOAD_KEEP_EXTENSIONS', False) - delete_tmp = self._config.get('UPLOAD_DELETE_TMP', True) + file_dir = self._config.get("UPLOAD_DIR") + keep_filename = self._config.get("UPLOAD_KEEP_FILENAME", False) + keep_extensions = self._config.get("UPLOAD_KEEP_EXTENSIONS", False) + delete_tmp = self._config.get("UPLOAD_DELETE_TMP", True) # If we have a directory and are to keep the filename... if file_dir is not None and keep_filename: @@ -419,8 +419,8 @@ def _get_disk_file(self): path = os.path.join(file_dir, fname) try: self.logger.info("Opening file: %r", path) - tmp_file = open(path, 'w+b') - except OSError as e: + tmp_file = open(path, "w+b") + except OSError: tmp_file = None self.logger.exception("Error opening temporary file") @@ -435,18 +435,17 @@ def _get_disk_file(self): if isinstance(ext, bytes): ext = ext.decode(sys.getfilesystemencoding()) - options['suffix'] = ext + options["suffix"] = ext if file_dir is not None: d = file_dir if isinstance(d, bytes): d = d.decode(sys.getfilesystemencoding()) - options['dir'] = d - options['delete'] = delete_tmp + options["dir"] = d + options["delete"] = delete_tmp # Create a temporary (named) file with the appropriate settings. - self.logger.info("Creating a temporary file with options: %r", - options) + self.logger.info("Creating a temporary file with options: %r", options) try: tmp_file = tempfile.NamedTemporaryFile(**options) except OSError: @@ -483,18 +482,18 @@ def on_data(self, data): # If the bytes written isn't the same as the length, just return. if bwritten != len(data): - self.logger.warning("bwritten != len(data) (%d != %d)", bwritten, - len(data)) + self.logger.warning("bwritten != len(data) (%d != %d)", bwritten, len(data)) return bwritten # Keep track of how many bytes we've written. self._bytes_written += bwritten # If we're in-memory and are over our limit, we create a file. - if (self._in_memory and - self._config.get('MAX_MEMORY_FILE_SIZE') is not None and - (self._bytes_written > - self._config.get('MAX_MEMORY_FILE_SIZE'))): + if ( + self._in_memory + and self._config.get("MAX_MEMORY_FILE_SIZE") is not None + and (self._bytes_written > self._config.get("MAX_MEMORY_FILE_SIZE")) + ): self.logger.info("Flushing to disk") self.flush_to_disk() @@ -502,8 +501,7 @@ def on_data(self, data): return bwritten def on_end(self): - """This method is called whenever the Field is finalized. - """ + """This method is called whenever the Field is finalized.""" # Flush the underlying file object self._fileobj.flush() @@ -521,11 +519,7 @@ def close(self): self._fileobj.close() def __repr__(self): - return "{}(file_name={!r}, field_name={!r})".format( - self.__class__.__name__, - self.file_name, - self.field_name - ) + return "{}(file_name={!r}, field_name={!r})".format(self.__class__.__name__, self.file_name, self.field_name) class BaseParser: @@ -548,6 +542,7 @@ class BaseParser: The callback is not passed a copy of the data, since copying severely hurts performance. """ + def __init__(self): self.logger = logging.getLogger(__name__) @@ -593,15 +588,15 @@ def set_callback(self, name, new_func): exist). 
""" if new_func is None: - self.callbacks.pop('on_' + name, None) + self.callbacks.pop("on_" + name, None) else: - self.callbacks['on_' + name] = new_func + self.callbacks["on_" + name] = new_func def close(self): - pass # pragma: no cover + pass # pragma: no cover def finalize(self): - pass # pragma: no cover + pass # pragma: no cover def __repr__(self): return "%s()" % self.__class__.__name__ @@ -634,14 +629,14 @@ class OctetStreamParser(BaseParser): :param max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded. """ - def __init__(self, callbacks={}, max_size=float('inf')): + + def __init__(self, callbacks={}, max_size=float("inf")): super().__init__() self.callbacks = callbacks self._started = False if not isinstance(max_size, Number) or max_size < 1: - raise ValueError("max_size must be a positive number, not %r" % - max_size) + raise ValueError("max_size must be a positive number, not %r" % max_size) self.max_size = max_size self._current_size = 0 @@ -652,7 +647,7 @@ def write(self, data): :param data: a bytestring """ if not self._started: - self.callback('start') + self.callback("start") self._started = True # Truncate data length. @@ -660,22 +655,25 @@ def write(self, data): if (self._current_size + data_len) > self.max_size: # We truncate the length of data that we are to process. new_size = int(self.max_size - self._current_size) - self.logger.warning("Current size is %d (max %d), so truncating " - "data length from %d to %d", - self._current_size, self.max_size, data_len, - new_size) + self.logger.warning( + "Current size is %d (max %d), so truncating data length from %d to %d", + self._current_size, + self.max_size, + data_len, + new_size, + ) data_len = new_size # Increment size, then callback, in case there's an exception. self._current_size += data_len - self.callback('data', data, 0, data_len) + self.callback("data", data, 0, data_len) return data_len def finalize(self): """Finalize this parser, which signals to that we are finished parsing, and sends the on_end callback. """ - self.callback('end') + self.callback("end") def __repr__(self): return "%s()" % self.__class__.__name__ @@ -726,8 +724,8 @@ class QuerystringParser(BaseParser): :param max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded. """ - def __init__(self, callbacks={}, strict_parsing=False, - max_size=float('inf')): + + def __init__(self, callbacks={}, strict_parsing=False, max_size=float("inf")): super().__init__() self.state = STATE_BEFORE_FIELD self._found_sep = False @@ -736,8 +734,7 @@ def __init__(self, callbacks={}, strict_parsing=False, # Max-size stuff if not isinstance(max_size, Number) or max_size < 1: - raise ValueError("max_size must be a positive number, not %r" % - max_size) + raise ValueError("max_size must be a positive number, not %r" % max_size) self.max_size = max_size self._current_size = 0 @@ -759,10 +756,13 @@ def write(self, data): if (self._current_size + data_len) > self.max_size: # We truncate the length of data that we are to process. 
new_size = int(self.max_size - self._current_size) - self.logger.warning("Current size is %d (max %d), so truncating " - "data length from %d to %d", - self._current_size, self.max_size, data_len, - new_size) + self.logger.warning( + "Current size is %d (max %d), so truncating data length from %d to %d", + self._current_size, + self.max_size, + data_len, + new_size, + ) data_len = new_size l = 0 @@ -794,15 +794,11 @@ def _internal_write(self, data, length): if found_sep: # If we're parsing strictly, we disallow blank chunks. if strict_parsing: - e = QuerystringParseError( - "Skipping duplicate ampersand/semicolon at " - "%d" % i - ) + e = QuerystringParseError("Skipping duplicate ampersand/semicolon at %d" % i) e.offset = i raise e else: - self.logger.debug("Skipping duplicate ampersand/" - "semicolon at %d", i) + self.logger.debug("Skipping duplicate ampersand/semicolon at %d", i) else: # This case is when we're skipping the (first) # separator between fields, so we just set our flag @@ -812,7 +808,7 @@ def _internal_write(self, data, length): # Emit a field-start event, and go to that state. Also, # reset the "found_sep" flag, for the next time we get to # this state. - self.callback('field_start') + self.callback("field_start") i -= 1 state = STATE_FIELD_NAME found_sep = False @@ -820,21 +816,21 @@ def _internal_write(self, data, length): elif state == STATE_FIELD_NAME: # Try and find a separator - we ensure that, if we do, we only # look for the equal sign before it. - sep_pos = data.find(b'&', i) + sep_pos = data.find(b"&", i) if sep_pos == -1: - sep_pos = data.find(b';', i) + sep_pos = data.find(b";", i) # See if we can find an equals sign in the remaining data. If # so, we can immediately emit the field name and jump to the # data state. if sep_pos != -1: - equals_pos = data.find(b'=', i, sep_pos) + equals_pos = data.find(b"=", i, sep_pos) else: - equals_pos = data.find(b'=', i) + equals_pos = data.find(b"=", i) if equals_pos != -1: # Emit this name. - self.callback('field_name', data, i, equals_pos) + self.callback("field_name", data, i, equals_pos) # Jump i to this position. Note that it will then have 1 # added to it below, which means the next iteration of this @@ -849,47 +845,46 @@ def _internal_write(self, data, length): # end - there's no data callback at all (not even with # a blank value). if sep_pos != -1: - self.callback('field_name', data, i, sep_pos) - self.callback('field_end') + self.callback("field_name", data, i, sep_pos) + self.callback("field_end") i = sep_pos - 1 state = STATE_BEFORE_FIELD else: # Otherwise, no separator in this block, so the # rest of this chunk must be a name. - self.callback('field_name', data, i, length) + self.callback("field_name", data, i, length) i = length else: # We're parsing strictly. If we find a separator, # this is an error - we require an equals sign. if sep_pos != -1: - e = QuerystringParseError( + e = QuerystringParseError( "When strict_parsing is True, we require an " "equals sign in all field chunks. Did not " - "find one in the chunk that starts at %d" % - (i,) + "find one in the chunk that starts at %d" % (i,) ) e.offset = i raise e # No separator in the rest of this chunk, so it's just # a field name. - self.callback('field_name', data, i, length) + self.callback("field_name", data, i, length) i = length elif state == STATE_FIELD_DATA: # Try finding either an ampersand or a semicolon after this # position. 
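# --- A sketch (not part of this diff) of the streaming callback protocol,
# wired up the same way as the test suite's setUp: field names and values
# may arrive in several pieces, with field_end marking each completed pair,
# even when a value straddles two write() calls.
from multipart.multipart import QuerystringParser

fields, name, value = [], [], []
p = QuerystringParser({
    "on_field_name": lambda d, s, e: name.append(d[s:e]),
    "on_field_data": lambda d, s, e: value.append(d[s:e]),
    "on_field_end": lambda: (fields.append((b"".join(name), b"".join(value))),
                             name.clear(), value.clear()),
})
p.write(b"foo=ba")   # deliberately split mid-value
p.write(b"r&baz=1")
p.finalize()         # flushes the trailing field
assert fields == [(b"foo", b"bar"), (b"baz", b"1")]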
- sep_pos = data.find(b'&', i) + sep_pos = data.find(b"&", i) if sep_pos == -1: - sep_pos = data.find(b';', i) + sep_pos = data.find(b";", i) # If we found it, callback this bit as data and then go back # to expecting to find a field. if sep_pos != -1: - self.callback('field_data', data, i, sep_pos) - self.callback('field_end') + self.callback("field_data", data, i, sep_pos) + self.callback("field_end") # Note that we go to the separator, which brings us to the # "before field" state. This allows us to properly emit @@ -900,10 +895,10 @@ def _internal_write(self, data, length): # Otherwise, emit the rest as data and finish. else: - self.callback('field_data', data, i, length) + self.callback("field_data", data, i, length) i = length - else: # pragma: no cover (error case) + else: # pragma: no cover (error case) msg = "Reached an unknown state %d at %d" % (state, i) self.logger.warning(msg) e = QuerystringParseError(msg) @@ -923,13 +918,12 @@ def finalize(self): """ # If we're currently in the middle of a field, we finish it. if self.state == STATE_FIELD_DATA: - self.callback('field_end') - self.callback('end') + self.callback("field_end") + self.callback("end") def __repr__(self): return "{}(strict_parsing={!r}, max_size={!r})".format( - self.__class__.__name__, - self.strict_parsing, self.max_size + self.__class__.__name__, self.strict_parsing, self.max_size ) @@ -988,17 +982,16 @@ class MultipartParser(BaseParser): i.e. unbounded. """ - def __init__(self, boundary, callbacks={}, max_size=float('inf')): + def __init__(self, boundary, callbacks={}, max_size=float("inf")): # Initialize parser state. super().__init__() - self.state = STATE_START + self.state = MultipartState.START self.index = self.flags = 0 self.callbacks = callbacks if not isinstance(max_size, Number) or max_size < 1: - raise ValueError("max_size must be a positive number, not %r" % - max_size) + raise ValueError("max_size must be a positive number, not %r" % max_size) self.max_size = max_size self._current_size = 0 @@ -1015,9 +1008,9 @@ def __init__(self, boundary, callbacks={}, max_size=float('inf')): # self.skip = tuple(skip) # Save our boundary. - if isinstance(boundary, str): # pragma: no cover - boundary = boundary.encode('latin-1') - self.boundary = b'\r\n--' + boundary + if isinstance(boundary, str): # pragma: no cover + boundary = boundary.encode("latin-1") + self.boundary = b"\r\n--" + boundary # Get a set of characters that belong to our boundary. self.boundary_chars = frozenset(self.boundary) @@ -1043,10 +1036,13 @@ def write(self, data): if (self._current_size + data_len) > self.max_size: # We truncate the length of data that we are to process. new_size = int(self.max_size - self._current_size) - self.logger.warning("Current size is %d (max %d), so truncating " - "data length from %d to %d", - self._current_size, self.max_size, data_len, - new_size) + self.logger.warning( + "Current size is %d (max %d), so truncating data length from %d to %d", + self._current_size, + self.max_size, + data_len, + new_size, + ) data_len = new_size l = 0 @@ -1104,7 +1100,7 @@ def data_callback(name, remaining=False): while i < length: c = data[i] - if state == STATE_START: + if state == MultipartState.START: # Skip leading newlines if c == CR or c == LF: i += 1 @@ -1116,10 +1112,10 @@ def data_callback(name, remaining=False): # Move to the next state, but decrement i so that we re-process # this character. 
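# --- A sketch (not from this diff) of driving MultipartParser directly.
# The boundary is stored internally as b"\r\n--" + boundary, and the state
# machine fires part/header callbacks as the bytes stream in.
from multipart.multipart import MultipartParser

events = []
p = MultipartParser(b"bound", {
    "on_part_begin": lambda: events.append("begin"),
    "on_part_data": lambda d, s, e: events.append(d[s:e]),
    "on_part_end": lambda: events.append("end_part"),
    "on_end": lambda: events.append("end"),
})
p.write(b'--bound\r\n'
        b'Content-Disposition: form-data; name="field"\r\n'
        b'\r\n'
        b'hello'
        b'\r\n--bound--\r\n')
p.finalize()
assert events == ["begin", b"hello", "end_part", "end"]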
- state = STATE_START_BOUNDARY + state = MultipartState.START_BOUNDARY i -= 1 - elif state == STATE_START_BOUNDARY: + elif state == MultipartState.START_BOUNDARY: # Check to ensure that the last 2 characters in our boundary # are CRLF. if index == len(boundary) - 2: @@ -1145,16 +1141,15 @@ def data_callback(name, remaining=False): index = 0 # Callback for the start of a part. - self.callback('part_begin') + self.callback("part_begin") # Move to the next character and state. - state = STATE_HEADER_FIELD_START + state = MultipartState.HEADER_FIELD_START else: # Check to ensure our boundary matches if c != boundary[index + 2]: - msg = "Did not find boundary character %r at index " \ - "%d" % (c, index + 2) + msg = "Did not find boundary character %r at index " "%d" % (c, index + 2) self.logger.warning(msg) e = MultipartParseError(msg) e.offset = i @@ -1163,25 +1158,25 @@ def data_callback(name, remaining=False): # Increment index into boundary and continue. index += 1 - elif state == STATE_HEADER_FIELD_START: + elif state == MultipartState.HEADER_FIELD_START: # Mark the start of a header field here, reset the index, and # continue parsing our header field. index = 0 # Set a mark of our header field. - set_mark('header_field') + set_mark("header_field") # Move to parsing header fields. - state = STATE_HEADER_FIELD + state = MultipartState.HEADER_FIELD i -= 1 - elif state == STATE_HEADER_FIELD: + elif state == MultipartState.HEADER_FIELD: # If we've reached a CR at the beginning of a header, it means # that we've reached the second of 2 newlines, and so there are # no more headers to parse. if c == CR: - delete_mark('header_field') - state = STATE_HEADERS_ALMOST_DONE + delete_mark("header_field") + state = MultipartState.HEADERS_ALMOST_DONE i += 1 continue @@ -1203,49 +1198,47 @@ def data_callback(name, remaining=False): raise e # Call our callback with the header field. - data_callback('header_field') + data_callback("header_field") # Move to parsing the header value. - state = STATE_HEADER_VALUE_START + state = MultipartState.HEADER_VALUE_START else: # Lower-case this character, and ensure that it is in fact # a valid letter. If not, it's an error. cl = lower_char(c) if cl < LOWER_A or cl > LOWER_Z: - msg = "Found non-alphanumeric character %r in " \ - "header at %d" % (c, i) + msg = "Found non-alphanumeric character %r in " "header at %d" % (c, i) self.logger.warning(msg) e = MultipartParseError(msg) e.offset = i raise e - elif state == STATE_HEADER_VALUE_START: + elif state == MultipartState.HEADER_VALUE_START: # Skip leading spaces. if c == SPACE: i += 1 continue # Mark the start of the header value. - set_mark('header_value') + set_mark("header_value") # Move to the header-value state, reprocessing this character. - state = STATE_HEADER_VALUE + state = MultipartState.HEADER_VALUE i -= 1 - elif state == STATE_HEADER_VALUE: + elif state == MultipartState.HEADER_VALUE: # If we've got a CR, we're nearly done our headers. Otherwise, # we do nothing and just move past this character. if c == CR: - data_callback('header_value') - self.callback('header_end') - state = STATE_HEADER_VALUE_ALMOST_DONE + data_callback("header_value") + self.callback("header_end") + state = MultipartState.HEADER_VALUE_ALMOST_DONE - elif state == STATE_HEADER_VALUE_ALMOST_DONE: + elif state == MultipartState.HEADER_VALUE_ALMOST_DONE: # The last character should be a LF. If not, it's an error. 
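# --- An aside (not part of this diff): because MultipartState is an IntEnum,
# the named states introduced earlier in this diff still compare equal to the
# plain integers the parser historically used, so the numeric protocol and
# any external comparisons against the old STATE_* constants keep working.
from multipart.multipart import MultipartState

assert MultipartState.HEADER_VALUE_ALMOST_DONE == 6
assert MultipartState(9) is MultipartState.PART_DATA
assert int(MultipartState.END) == 11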
if c != LF: - msg = "Did not find LF character at end of header " \ - "(found %r)" % (c,) + msg = "Did not find LF character at end of header " "(found %r)" % (c,) self.logger.warning(msg) e = MultipartParseError(msg) e.offset = i @@ -1254,9 +1247,9 @@ def data_callback(name, remaining=False): # Move back to the start of another header. Note that if that # state detects ANOTHER newline, it'll trigger the end of our # headers. - state = STATE_HEADER_FIELD_START + state = MultipartState.HEADER_FIELD_START - elif state == STATE_HEADERS_ALMOST_DONE: + elif state == MultipartState.HEADERS_ALMOST_DONE: # We're almost done our headers. This is reached when we parse # a CR at the beginning of a header, so our next character # should be a LF, or it's an error. @@ -1267,18 +1260,18 @@ def data_callback(name, remaining=False): e.offset = i raise e - self.callback('headers_finished') - state = STATE_PART_DATA_START + self.callback("headers_finished") + state = MultipartState.PART_DATA_START - elif state == STATE_PART_DATA_START: + elif state == MultipartState.PART_DATA_START: # Mark the start of our part data. - set_mark('part_data') + set_mark("part_data") # Start processing part data, including this character. - state = STATE_PART_DATA + state = MultipartState.PART_DATA i -= 1 - elif state == STATE_PART_DATA: + elif state == MultipartState.PART_DATA: # We're processing our part data right now. During this, we # need to efficiently search for our boundary, since any data # on any number of lines can be a part of the current data. @@ -1320,7 +1313,7 @@ def data_callback(name, remaining=False): # If we found a match for our boundary, we send the # existing data. if index == 0: - data_callback('part_data') + data_callback("part_data") # The current character matches, so continue! index += 1 @@ -1356,23 +1349,23 @@ def data_callback(name, remaining=False): # We need a LF character next. if c == LF: # Unset the part boundary flag. - flags &= (~FLAG_PART_BOUNDARY) + flags &= ~FLAG_PART_BOUNDARY # Callback indicating that we've reached the end of # a part, and are starting a new one. - self.callback('part_end') - self.callback('part_begin') + self.callback("part_end") + self.callback("part_begin") # Move to parsing new headers. index = 0 - state = STATE_HEADER_FIELD_START + state = MultipartState.HEADER_FIELD_START i += 1 continue # We didn't find an LF character, so no match. Reset # our index and clear our flag. index = 0 - flags &= (~FLAG_PART_BOUNDARY) + flags &= ~FLAG_PART_BOUNDARY # Otherwise, if we're at the last boundary (i.e. we've # seen a hyphen already)... @@ -1381,9 +1374,9 @@ def data_callback(name, remaining=False): if c == HYPHEN: # Callback to end the current part, and then the # message. - self.callback('part_end') - self.callback('end') - state = STATE_END + self.callback("part_end") + self.callback("end") + state = MultipartState.END else: # No match, so reset index. index = 0 @@ -1400,24 +1393,24 @@ def data_callback(name, remaining=False): elif prev_index > 0: # Callback to write the saved data. lb_data = join_bytes(self.lookbehind) - self.callback('part_data', lb_data, 0, prev_index) + self.callback("part_data", lb_data, 0, prev_index) # Overwrite our previous index. prev_index = 0 # Re-set our mark for part data. - set_mark('part_data') + set_mark("part_data") # Re-consider the current character, since this could be # the start of the boundary itself. i -= 1 - elif state == STATE_END: + elif state == MultipartState.END: # Do nothing and just consume a byte in the end state. 
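# --- A property-style sketch (not from this diff), in the spirit of the
# test suite's split_all helper: thanks to the lookbehind buffer, a boundary
# that straddles two write() calls is still recognized, and the partially
# matched bytes are never leaked as part data.
from multipart.multipart import MultipartParser

msg = (b'--bound\r\n'
       b'Content-Disposition: form-data; name="f"\r\n'
       b'\r\n'
       b'hello'
       b'\r\n--bound--\r\n')

for cut in range(1, len(msg)):
    parts = []
    p = MultipartParser(b"bound", {"on_part_data": lambda d, s, e: parts.append(d[s:e])})
    p.write(msg[:cut])
    p.write(msg[cut:])
    assert b"".join(parts) == b"hello"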
if c not in (CR, LF): self.logger.warning("Consuming a byte '0x%x' in the end state", c) - else: # pragma: no cover (error case) + else: # pragma: no cover (error case) # We got into a strange state somehow! Just stop processing. msg = "Reached an unknown state %d at %d" % (state, i) self.logger.warning(msg) @@ -1436,9 +1429,9 @@ def data_callback(name, remaining=False): # that we haven't yet reached the end of this 'thing'. So, by setting # the mark to 0, we cause any data callbacks that take place in future # calls to this function to start from the beginning of that buffer. - data_callback('header_field', True) - data_callback('header_value', True) - data_callback('part_data', True) + data_callback("header_field", True) + data_callback("header_value", True) + data_callback("part_data", True) # Save values to locals. self.state = state @@ -1456,7 +1449,7 @@ def finalize(self): are in the final state of the parser (i.e. the end of the multipart message is well-formed), and, if not, throw an error. """ - # TODO: verify that we're in the state STATE_END, otherwise throw an + # TODO: verify that we're in the state MultipartState.END, otherwise throw an # error or otherwise state that we're not finished parsing. pass @@ -1516,23 +1509,31 @@ class FormParser: default values. """ + #: This is the default configuration for our form parser. #: Note: all file sizes should be in bytes. DEFAULT_CONFIG = { - 'MAX_BODY_SIZE': float('inf'), - 'MAX_MEMORY_FILE_SIZE': 1 * 1024 * 1024, - 'UPLOAD_DIR': None, - 'UPLOAD_KEEP_FILENAME': False, - 'UPLOAD_KEEP_EXTENSIONS': False, - + "MAX_BODY_SIZE": float("inf"), + "MAX_MEMORY_FILE_SIZE": 1 * 1024 * 1024, + "UPLOAD_DIR": None, + "UPLOAD_KEEP_FILENAME": False, + "UPLOAD_KEEP_EXTENSIONS": False, # Error on invalid Content-Transfer-Encoding? - 'UPLOAD_ERROR_ON_BAD_CTE': False, + "UPLOAD_ERROR_ON_BAD_CTE": False, } - def __init__(self, content_type, on_field, on_file, on_end=None, - boundary=None, file_name=None, FileClass=File, - FieldClass=Field, config={}): - + def __init__( + self, + content_type, + on_field, + on_file, + on_end=None, + boundary=None, + file_name=None, + FileClass=File, + FieldClass=Field, + config={}, + ): self.logger = logging.getLogger(__name__) # Save variables. @@ -1555,7 +1556,7 @@ def __init__(self, content_type, on_field, on_file, on_end=None, self.config.update(config) # Depending on the Content-Type, we instantiate the correct parser. 
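# --- A sketch (not part of this diff) of the front door: FormParser picks
# one of the three parsers below based on the Content-Type it is given, and
# hands completed Field/File objects to the supplied callbacks.
from multipart.multipart import FormParser

fields = []
fp = FormParser("application/x-www-form-urlencoded",
                on_field=fields.append, on_file=lambda f: None)
fp.write(b"a=1&b=2")
fp.finalize()
assert [(f.field_name, f.value) for f in fields] == [(b"a", b"1"), (b"b", b"2")]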
- if content_type == 'application/octet-stream': + if content_type == "application/octet-stream": # Work around the lack of 'nonlocal' in Py2 class vars: f = None @@ -1577,19 +1578,12 @@ def on_end(): if self.on_end is not None: self.on_end() - callbacks = { - 'on_start': on_start, - 'on_data': on_data, - 'on_end': on_end, - } + callbacks = {"on_start": on_start, "on_data": on_data, "on_end": on_end} # Instantiate an octet-stream parser - parser = OctetStreamParser(callbacks, - max_size=self.config['MAX_BODY_SIZE']) - - elif (content_type == 'application/x-www-form-urlencoded' or - content_type == 'application/x-url-encoded'): + parser = OctetStreamParser(callbacks, max_size=self.config["MAX_BODY_SIZE"]) + elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded": name_buffer = [] class vars: @@ -1603,7 +1597,7 @@ def on_field_name(data, start, end): def on_field_data(data, start, end): if vars.f is None: - vars.f = FieldClass(b''.join(name_buffer)) + vars.f = FieldClass(b"".join(name_buffer)) del name_buffer[:] vars.f.write(data[start:end]) @@ -1612,7 +1606,7 @@ def on_field_end(): if vars.f is None: # If we get here, it's because there was no field data. # We create a field, set it to None, and then continue. - vars.f = FieldClass(b''.join(name_buffer)) + vars.f = FieldClass(b"".join(name_buffer)) del name_buffer[:] vars.f.set_none() @@ -1626,20 +1620,17 @@ def on_end(): # Setup callbacks. callbacks = { - 'on_field_start': on_field_start, - 'on_field_name': on_field_name, - 'on_field_data': on_field_data, - 'on_field_end': on_field_end, - 'on_end': on_end, + "on_field_start": on_field_start, + "on_field_name": on_field_name, + "on_field_data": on_field_data, + "on_field_end": on_field_end, + "on_end": on_end, } # Instantiate parser. - parser = QuerystringParser( - callbacks=callbacks, - max_size=self.config['MAX_BODY_SIZE'] - ) + parser = QuerystringParser(callbacks=callbacks, max_size=self.config["MAX_BODY_SIZE"]) - elif content_type == 'multipart/form-data': + elif content_type == "multipart/form-data": if boundary is None: self.logger.error("No boundary given") raise FormParserError("No boundary given") @@ -1676,7 +1667,7 @@ def on_header_value(data, start, end): header_value.append(data[start:end]) def on_header_end(): - headers[b''.join(header_name)] = b''.join(header_value) + headers[b"".join(header_name)] = b"".join(header_value) del header_name[:] del header_value[:] @@ -1686,12 +1677,12 @@ def on_headers_finished(): # Parse the content-disposition header. # TODO: handle mixed case - content_disp = headers.get(b'Content-Disposition') + content_disp = headers.get(b"Content-Disposition") disp, options = parse_options_header(content_disp) # Get the field and filename. - field_name = options.get(b'name') - file_name = options.get(b'filename') + field_name = options.get(b"name") + file_name = options.get(b"filename") # TODO: check for errors # Create the proper class. @@ -1704,29 +1695,21 @@ def on_headers_finished(): # Parse the given Content-Transfer-Encoding to determine what # we need to do with the incoming data. # TODO: check that we properly handle 8bit / 7bit encoding. 
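# --- A distilled sketch (not from this diff; pick_writer is a hypothetical
# name) of the Content-Transfer-Encoding wiring that follows: 7bit, 8bit and
# binary bodies pass through to the target untouched, while base64 and
# quoted-printable bodies are routed through a streaming decoder first.
from multipart.decoders import Base64Decoder, QuotedPrintableDecoder

def pick_writer(target, transfer_encoding=b"7bit"):
    if transfer_encoding in (b"binary", b"8bit", b"7bit"):
        return target
    if transfer_encoding == b"base64":
        return Base64Decoder(target)
    if transfer_encoding == b"quoted-printable":
        return QuotedPrintableDecoder(target)
    raise ValueError("unknown Content-Transfer-Encoding")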
- transfer_encoding = headers.get(b'Content-Transfer-Encoding', - b'7bit') + transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit") - if (transfer_encoding == b'binary' or - transfer_encoding == b'8bit' or - transfer_encoding == b'7bit'): + if transfer_encoding == b"binary" or transfer_encoding == b"8bit" or transfer_encoding == b"7bit": vars.writer = vars.f - elif transfer_encoding == b'base64': + elif transfer_encoding == b"base64": vars.writer = Base64Decoder(vars.f) - elif transfer_encoding == b'quoted-printable': + elif transfer_encoding == b"quoted-printable": vars.writer = QuotedPrintableDecoder(vars.f) else: - self.logger.warning("Unknown Content-Transfer-Encoding: " - "%r", transfer_encoding) - if self.config['UPLOAD_ERROR_ON_BAD_CTE']: - raise FormParserError( - 'Unknown Content-Transfer-Encoding "{}"'.format( - transfer_encoding - ) - ) + self.logger.warning("Unknown Content-Transfer-Encoding: %r", transfer_encoding) + if self.config["UPLOAD_ERROR_ON_BAD_CTE"]: + raise FormParserError('Unknown Content-Transfer-Encoding "{}"'.format(transfer_encoding)) else: # If we aren't erroring, then we just treat this as an # unencoded Content-Transfer-Encoding. @@ -1739,25 +1722,22 @@ def on_end(): # These are our callbacks for the parser. callbacks = { - 'on_part_begin': on_part_begin, - 'on_part_data': on_part_data, - 'on_part_end': on_part_end, - 'on_header_field': on_header_field, - 'on_header_value': on_header_value, - 'on_header_end': on_header_end, - 'on_headers_finished': on_headers_finished, - 'on_end': on_end, + "on_part_begin": on_part_begin, + "on_part_data": on_part_data, + "on_part_end": on_part_end, + "on_header_field": on_header_field, + "on_header_value": on_header_value, + "on_header_end": on_header_end, + "on_headers_finished": on_headers_finished, + "on_end": on_end, } # Instantiate a multipart parser. - parser = MultipartParser(boundary, callbacks, - max_size=self.config['MAX_BODY_SIZE']) + parser = MultipartParser(boundary, callbacks, max_size=self.config["MAX_BODY_SIZE"]) else: self.logger.warning("Unknown Content-Type: %r", content_type) - raise FormParserError("Unknown Content-Type: {}".format( - content_type - )) + raise FormParserError("Unknown Content-Type: {}".format(content_type)) self.parser = parser @@ -1773,24 +1753,19 @@ def write(self, data): def finalize(self): """Finalize the parser.""" - if self.parser is not None and hasattr(self.parser, 'finalize'): + if self.parser is not None and hasattr(self.parser, "finalize"): self.parser.finalize() def close(self): """Close the parser.""" - if self.parser is not None and hasattr(self.parser, 'close'): + if self.parser is not None and hasattr(self.parser, "close"): self.parser.close() def __repr__(self): - return "{}(content_type={!r}, parser={!r})".format( - self.__class__.__name__, - self.content_type, - self.parser, - ) + return "{}(content_type={!r}, parser={!r})".format(self.__class__.__name__, self.content_type, self.parser) -def create_form_parser(headers, on_field, on_file, trust_x_headers=False, - config={}): +def create_form_parser(headers, on_field, on_file, trust_x_headers=False, config={}): """This function is a helper function to aid in creating a FormParser instances. Given a dictionary-like headers object, it will determine the correct information needed, instantiate a FormParser with the @@ -1810,7 +1785,7 @@ def create_form_parser(headers, on_field, on_file, trust_x_headers=False, :param config: Configuration variables to pass to the FormParser. 
""" - content_type = headers.get('Content-Type') + content_type = headers.get("Content-Type") if content_type is None: logging.getLogger(__name__).warning("No Content-Type header given") raise ValueError("No Content-Type header given!") @@ -1818,28 +1793,22 @@ def create_form_parser(headers, on_field, on_file, trust_x_headers=False, # Boundaries are optional (the FormParser will raise if one is needed # but not given). content_type, params = parse_options_header(content_type) - boundary = params.get(b'boundary') + boundary = params.get(b"boundary") # We need content_type to be a string, not a bytes object. - content_type = content_type.decode('latin-1') + content_type = content_type.decode("latin-1") # File names are optional. - file_name = headers.get('X-File-Name') + file_name = headers.get("X-File-Name") # Instantiate a form parser. - form_parser = FormParser(content_type, - on_field, - on_file, - boundary=boundary, - file_name=file_name, - config=config) + form_parser = FormParser(content_type, on_field, on_file, boundary=boundary, file_name=file_name, config=config) # Return our parser. return form_parser -def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, - **kwargs): +def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, **kwargs): """This function is useful if you just want to parse a request body, without too much work. Pass it a dictionary-like object of the request's headers, and a file-like object for the input stream, along with two @@ -1864,11 +1833,11 @@ def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, # Read chunks of 100KiB and write to the parser, but never read more than # the given Content-Length, if any. - content_length = headers.get('Content-Length') + content_length = headers.get("Content-Length") if content_length is not None: content_length = int(content_length) else: - content_length = float('inf') + content_length = float("inf") bytes_read = 0 while True: diff --git a/pyproject.toml b/pyproject.toml index 085ab6a..5833d83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dev = [ "PyYAML==6.0.1", "invoke==2.2.0", "pytest-timeout==2.2.0", + "ruff==0.2.1", "hatch", ] @@ -58,3 +59,15 @@ packages = ["multipart"] [tool.hatch.build.targets.sdist] include = ["/multipart", "/tests"] + +[tool.ruff] +line-length = 120 +select = ["E", "F", "I", "FA"] +ignore = ["B904", "B028", "F841", "E741"] + +[tool.ruff.format] +skip-magic-trailing-comma = true + +[tool.ruff.lint.isort] +combine-as-imports = true +split-on-trailing-comma = false diff --git a/requirements.txt b/requirements.txt index 9622472..23baf78 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ pluggy==1.4.0 py==1.11.0 pytest==8.0.0 PyYAML==6.0.1 +ruff==0.2.1 diff --git a/tests/compat.py b/tests/compat.py index 897188d..8b0ccae 100644 --- a/tests/compat.py +++ b/tests/compat.py @@ -1,8 +1,8 @@ +import functools import os import re import sys import types -import functools def ensure_in_path(path): @@ -10,7 +10,7 @@ def ensure_in_path(path): Ensure that a given path is in the sys.path array """ if not os.path.isdir(path): - raise RuntimeError('Tried to add nonexisting path') + raise RuntimeError("Tried to add nonexisting path") def _samefile(x, y): try: @@ -44,7 +44,9 @@ def _samefile(x, y): xfail = pytest.mark.xfail else: - slow_test = lambda x: x + + def slow_test(x): + return x def xfail(*args, **kwargs): if len(args) > 0 and isinstance(args[0], types.FunctionType): @@ -64,8 +66,8 @@ def 
parametrize(field_names, field_values): # Create a decorator that saves this list of field names and values on the # function for later parametrizing. def decorator(func): - func.__dict__['param_names'] = field_names - func.__dict__['param_values'] = field_values + func.__dict__["param_names"] = field_names + func.__dict__["param_values"] = field_values return func return decorator @@ -73,7 +75,7 @@ def decorator(func): # This is a metaclass that actually performs the parametrization. class ParametrizingMetaclass(type): - IDENTIFIER_RE = re.compile('[^A-Za-z0-9]') + IDENTIFIER_RE = re.compile("[^A-Za-z0-9]") def __new__(klass, name, bases, attrs): new_attrs = attrs.copy() @@ -82,8 +84,8 @@ def __new__(klass, name, bases, attrs): if not isinstance(attr, types.FunctionType): continue - param_names = attr.__dict__.pop('param_names', None) - param_values = attr.__dict__.pop('param_values', None) + param_names = attr.__dict__.pop("param_names", None) + param_values = attr.__dict__.pop("param_values", None) if param_names is None or param_values is None: continue @@ -92,9 +94,7 @@ def __new__(klass, name, bases, attrs): assert len(param_names) == len(values) # Get a repr of the values, and fix it to be a valid identifier - human = '_'.join( - [klass.IDENTIFIER_RE.sub('', repr(x)) for x in values] - ) + human = "_".join([klass.IDENTIFIER_RE.sub("", repr(x)) for x in values]) # Create a new name. # new_name = attr.__name__ + "_%d" % i @@ -128,6 +128,4 @@ def new_func(self): # This is a class decorator that actually applies the above metaclass. def parametrize_class(klass): - return ParametrizingMetaclass(klass.__name__, - klass.__bases__, - klass.__dict__) + return ParametrizingMetaclass(klass.__name__, klass.__bases__, klass.__dict__) diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 5cfacf4..b9cba86 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1,19 +1,30 @@ import os -import sys -import yaml import random +import sys import tempfile import unittest -from .compat import ( - parametrize, - parametrize_class, - slow_test, -) from io import BytesIO from unittest.mock import Mock -from multipart.multipart import * +import yaml + +from multipart.decoders import Base64Decoder, QuotedPrintableDecoder +from multipart.exceptions import DecodeError, FileError, FormParserError, MultipartParseError +from multipart.multipart import ( + BaseParser, + Field, + File, + FormParser, + MultipartParser, + OctetStreamParser, + QuerystringParseError, + QuerystringParser, + create_form_parser, + parse_form, + parse_options_header, +) +from .compat import parametrize, parametrize_class, slow_test # Get the current directory for our later test cases. 
curr_dir = os.path.abspath(os.path.dirname(__file__)) @@ -28,53 +39,53 @@ def force_bytes(val): class TestField(unittest.TestCase): def setUp(self): - self.f = Field('foo') + self.f = Field("foo") def test_name(self): - self.assertEqual(self.f.field_name, 'foo') + self.assertEqual(self.f.field_name, "foo") def test_data(self): - self.f.write(b'test123') - self.assertEqual(self.f.value, b'test123') + self.f.write(b"test123") + self.assertEqual(self.f.value, b"test123") def test_cache_expiration(self): - self.f.write(b'test') - self.assertEqual(self.f.value, b'test') - self.f.write(b'123') - self.assertEqual(self.f.value, b'test123') + self.f.write(b"test") + self.assertEqual(self.f.value, b"test") + self.f.write(b"123") + self.assertEqual(self.f.value, b"test123") def test_finalize(self): - self.f.write(b'test123') + self.f.write(b"test123") self.f.finalize() - self.assertEqual(self.f.value, b'test123') + self.assertEqual(self.f.value, b"test123") def test_close(self): - self.f.write(b'test123') + self.f.write(b"test123") self.f.close() - self.assertEqual(self.f.value, b'test123') + self.assertEqual(self.f.value, b"test123") def test_from_value(self): - f = Field.from_value(b'name', b'value') - self.assertEqual(f.field_name, b'name') - self.assertEqual(f.value, b'value') + f = Field.from_value(b"name", b"value") + self.assertEqual(f.field_name, b"name") + self.assertEqual(f.value, b"value") - f2 = Field.from_value(b'name', None) + f2 = Field.from_value(b"name", None) self.assertEqual(f2.value, None) def test_equality(self): - f1 = Field.from_value(b'name', b'value') - f2 = Field.from_value(b'name', b'value') + f1 = Field.from_value(b"name", b"value") + f2 = Field.from_value(b"name", b"value") self.assertEqual(f1, f2) def test_equality_with_other(self): - f = Field.from_value(b'foo', b'bar') - self.assertFalse(f == b'foo') - self.assertFalse(b'foo' == f) + f = Field.from_value(b"foo", b"bar") + self.assertFalse(f == b"foo") + self.assertFalse(b"foo" == f) def test_set_none(self): - f = Field(b'foo') - self.assertEqual(f.value, b'') + f = Field(b"foo") + self.assertEqual(f.value, b"") f.set_none() self.assertEqual(f.value, None) @@ -84,7 +95,7 @@ class TestFile(unittest.TestCase): def setUp(self): self.c = {} self.d = force_bytes(tempfile.mkdtemp()) - self.f = File(b'foo.txt', config=self.c) + self.f = File(b"foo.txt", config=self.c) def assert_data(self, data): f = self.f.file_object @@ -98,26 +109,26 @@ def assert_exists(self): self.assertTrue(os.path.exists(full_path)) def test_simple(self): - self.f.write(b'foobar') - self.assert_data(b'foobar') + self.f.write(b"foobar") + self.assert_data(b"foobar") def test_invalid_write(self): m = Mock() m.write.return_value = 5 self.f._fileobj = m - v = self.f.write(b'foobar') + v = self.f.write(b"foobar") self.assertEqual(v, 5) def test_file_fallback(self): - self.c['MAX_MEMORY_FILE_SIZE'] = 1 + self.c["MAX_MEMORY_FILE_SIZE"] = 1 - self.f.write(b'1') + self.f.write(b"1") self.assertTrue(self.f.in_memory) - self.assert_data(b'1') + self.assert_data(b"1") - self.f.write(b'123') + self.f.write(b"123") self.assertFalse(self.f.in_memory) - self.assert_data(b'123') + self.assert_data(b"123") # Test flushing too. 
old_obj = self.f.file_object @@ -126,23 +137,23 @@ def test_file_fallback(self): self.assertIs(self.f.file_object, old_obj) def test_file_fallback_with_data(self): - self.c['MAX_MEMORY_FILE_SIZE'] = 10 + self.c["MAX_MEMORY_FILE_SIZE"] = 10 - self.f.write(b'1' * 10) + self.f.write(b"1" * 10) self.assertTrue(self.f.in_memory) - self.f.write(b'2' * 10) + self.f.write(b"2" * 10) self.assertFalse(self.f.in_memory) - self.assert_data(b'11111111112222222222') + self.assert_data(b"11111111112222222222") def test_file_name(self): # Write to this dir. - self.c['UPLOAD_DIR'] = self.d - self.c['MAX_MEMORY_FILE_SIZE'] = 10 + self.c["UPLOAD_DIR"] = self.d + self.c["MAX_MEMORY_FILE_SIZE"] = 10 # Write. - self.f.write(b'12345678901') + self.f.write(b"12345678901") self.assertFalse(self.f.in_memory) # Assert that the file exists @@ -151,135 +162,124 @@ def test_file_name(self): def test_file_full_name(self): # Write to this dir. - self.c['UPLOAD_DIR'] = self.d - self.c['UPLOAD_KEEP_FILENAME'] = True - self.c['MAX_MEMORY_FILE_SIZE'] = 10 + self.c["UPLOAD_DIR"] = self.d + self.c["UPLOAD_KEEP_FILENAME"] = True + self.c["MAX_MEMORY_FILE_SIZE"] = 10 # Write. - self.f.write(b'12345678901') + self.f.write(b"12345678901") self.assertFalse(self.f.in_memory) # Assert that the file exists - self.assertEqual(self.f.actual_file_name, b'foo') + self.assertEqual(self.f.actual_file_name, b"foo") self.assert_exists() def test_file_full_name_with_ext(self): - self.c['UPLOAD_DIR'] = self.d - self.c['UPLOAD_KEEP_FILENAME'] = True - self.c['UPLOAD_KEEP_EXTENSIONS'] = True - self.c['MAX_MEMORY_FILE_SIZE'] = 10 + self.c["UPLOAD_DIR"] = self.d + self.c["UPLOAD_KEEP_FILENAME"] = True + self.c["UPLOAD_KEEP_EXTENSIONS"] = True + self.c["MAX_MEMORY_FILE_SIZE"] = 10 # Write. - self.f.write(b'12345678901') + self.f.write(b"12345678901") self.assertFalse(self.f.in_memory) # Assert that the file exists - self.assertEqual(self.f.actual_file_name, b'foo.txt') - self.assert_exists() - - def test_file_full_name_with_ext(self): - self.c['UPLOAD_DIR'] = self.d - self.c['UPLOAD_KEEP_FILENAME'] = True - self.c['UPLOAD_KEEP_EXTENSIONS'] = True - self.c['MAX_MEMORY_FILE_SIZE'] = 10 - - # Write. - self.f.write(b'12345678901') - self.assertFalse(self.f.in_memory) - - # Assert that the file exists - self.assertEqual(self.f.actual_file_name, b'foo.txt') + self.assertEqual(self.f.actual_file_name, b"foo.txt") self.assert_exists() def test_no_dir_with_extension(self): - self.c['UPLOAD_KEEP_EXTENSIONS'] = True - self.c['MAX_MEMORY_FILE_SIZE'] = 10 + self.c["UPLOAD_KEEP_EXTENSIONS"] = True + self.c["MAX_MEMORY_FILE_SIZE"] = 10 # Write. - self.f.write(b'12345678901') + self.f.write(b"12345678901") self.assertFalse(self.f.in_memory) # Assert that the file exists ext = os.path.splitext(self.f.actual_file_name)[1] - self.assertEqual(ext, b'.txt') + self.assertEqual(ext, b".txt") self.assert_exists() def test_invalid_dir_with_name(self): # Write to this dir. - self.c['UPLOAD_DIR'] = force_bytes(os.path.join('/', 'tmp', 'notexisting')) - self.c['UPLOAD_KEEP_FILENAME'] = True - self.c['MAX_MEMORY_FILE_SIZE'] = 5 + self.c["UPLOAD_DIR"] = force_bytes(os.path.join("/", "tmp", "notexisting")) + self.c["UPLOAD_KEEP_FILENAME"] = True + self.c["MAX_MEMORY_FILE_SIZE"] = 5 # Write. with self.assertRaises(FileError): - self.f.write(b'1234567890') + self.f.write(b"1234567890") def test_invalid_dir_no_name(self): # Write to this dir. 
-        self.c['UPLOAD_DIR'] = force_bytes(os.path.join('/', 'tmp', 'notexisting'))
-        self.c['UPLOAD_KEEP_FILENAME'] = False
-        self.c['MAX_MEMORY_FILE_SIZE'] = 5
+        self.c["UPLOAD_DIR"] = force_bytes(os.path.join("/", "tmp", "notexisting"))
+        self.c["UPLOAD_KEEP_FILENAME"] = False
+        self.c["MAX_MEMORY_FILE_SIZE"] = 5

         # Write.
         with self.assertRaises(FileError):
-            self.f.write(b'1234567890')
+            self.f.write(b"1234567890")


 # TODO: test uploading two files with the same name.


 class TestParseOptionsHeader(unittest.TestCase):
     def test_simple(self):
-        t, p = parse_options_header('application/json')
-        self.assertEqual(t, b'application/json')
+        t, p = parse_options_header("application/json")
+        self.assertEqual(t, b"application/json")
         self.assertEqual(p, {})

     def test_blank(self):
-        t, p = parse_options_header('')
-        self.assertEqual(t, b'')
+        t, p = parse_options_header("")
+        self.assertEqual(t, b"")
         self.assertEqual(p, {})

     def test_single_param(self):
-        t, p = parse_options_header('application/json;par=val')
-        self.assertEqual(t, b'application/json')
-        self.assertEqual(p, {b'par': b'val'})
+        t, p = parse_options_header("application/json;par=val")
+        self.assertEqual(t, b"application/json")
+        self.assertEqual(p, {b"par": b"val"})

     def test_single_param_with_spaces(self):
-        t, p = parse_options_header(b'application/json; par=val')
-        self.assertEqual(t, b'application/json')
-        self.assertEqual(p, {b'par': b'val'})
+        t, p = parse_options_header(b"application/json; par=val")
+        self.assertEqual(t, b"application/json")
+        self.assertEqual(p, {b"par": b"val"})

     def test_multiple_params(self):
-        t, p = parse_options_header(b'application/json;par=val;asdf=foo')
-        self.assertEqual(t, b'application/json')
-        self.assertEqual(p, {b'par': b'val', b'asdf': b'foo'})
+        t, p = parse_options_header(b"application/json;par=val;asdf=foo")
+        self.assertEqual(t, b"application/json")
+        self.assertEqual(p, {b"par": b"val", b"asdf": b"foo"})

     def test_quoted_param(self):
         t, p = parse_options_header(b'application/json;param="quoted"')
-        self.assertEqual(t, b'application/json')
-        self.assertEqual(p, {b'param': b'quoted'})
+        self.assertEqual(t, b"application/json")
+        self.assertEqual(p, {b"param": b"quoted"})

     def test_quoted_param_with_semicolon(self):
         t, p = parse_options_header(b'application/json;param="quoted;with;semicolons"')
-        self.assertEqual(p[b'param'], b'quoted;with;semicolons')
+        self.assertEqual(p[b"param"], b"quoted;with;semicolons")

     def test_quoted_param_with_escapes(self):
         t, p = parse_options_header(b'application/json;param="This \\" is \\" a \\" quote"')
-        self.assertEqual(p[b'param'], b'This " is " a " quote')
+        self.assertEqual(p[b"param"], b'This " is " a " quote')

     def test_handles_ie6_bug(self):
         t, p = parse_options_header(b'text/plain; filename="C:\\this\\is\\a\\path\\file.txt"')
-        self.assertEqual(p[b'filename'], b'file.txt')
-
+        self.assertEqual(p[b"filename"], b"file.txt")
+
     def test_redos_attack_header(self):
-        t, p = parse_options_header(b'application/x-www-form-urlencoded; !="\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\')
+        t, p = parse_options_header(
+            b'application/x-www-form-urlencoded; !="'
+            b"\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"
+        )
         # If vulnerable, this test wouldn't finish, the line above would hang
-        self.assertIn(b'"\\', p[b'!'])
+        self.assertIn(b'"\\', p[b"!"])

     def test_handles_rfc_2231(self):
-        t, p = parse_options_header(b'text/plain; param*=us-ascii\'en-us\'encoded%20message')
+        t, p = parse_options_header(b"text/plain; param*=us-ascii'en-us'encoded%20message")

-        self.assertEqual(p[b'param'], b'encoded message')
+        self.assertEqual(p[b"param"], b"encoded message")


 class TestBaseParser(unittest.TestCase):
@@ -290,25 +290,26 @@ def setUp(self):
     def test_callbacks(self):
         # The stupid list-ness is to get around lack of nonlocal on py2
         l = [0]
+
         def on_foo():
             l[0] += 1

-        self.b.set_callback('foo', on_foo)
-        self.b.callback('foo')
+        self.b.set_callback("foo", on_foo)
+        self.b.callback("foo")
         self.assertEqual(l[0], 1)

-        self.b.set_callback('foo', None)
-        self.b.callback('foo')
+        self.b.set_callback("foo", None)
+        self.b.callback("foo")
         self.assertEqual(l[0], 1)


 class TestQuerystringParser(unittest.TestCase):
     def assert_fields(self, *args, **kwargs):
-        if kwargs.pop('finalize', True):
+        if kwargs.pop("finalize", True):
             self.p.finalize()

         self.assertEqual(self.f, list(args))
-        if kwargs.get('reset', True):
+        if kwargs.get("reset", True):
             self.f = []

     def setUp(self):
@@ -327,103 +328,80 @@ def on_field_data(data, start, end):
             data_buffer.append(data[start:end])

         def on_field_end():
-            self.f.append((
-                b''.join(name_buffer),
-                b''.join(data_buffer)
-            ))
+            self.f.append((b"".join(name_buffer), b"".join(data_buffer)))

             del name_buffer[:]
             del data_buffer[:]

-        callbacks = {
-            'on_field_name': on_field_name,
-            'on_field_data': on_field_data,
-            'on_field_end': on_field_end
-        }
+        callbacks = {"on_field_name": on_field_name, "on_field_data": on_field_data, "on_field_end": on_field_end}

         self.p = QuerystringParser(callbacks)

     def test_simple_querystring(self):
-        self.p.write(b'foo=bar')
+        self.p.write(b"foo=bar")

-        self.assert_fields((b'foo', b'bar'))
+        self.assert_fields((b"foo", b"bar"))

     def test_querystring_blank_beginning(self):
-        self.p.write(b'&foo=bar')
+        self.p.write(b"&foo=bar")

-        self.assert_fields((b'foo', b'bar'))
+        self.assert_fields((b"foo", b"bar"))

     def test_querystring_blank_end(self):
-        self.p.write(b'foo=bar&')
+        self.p.write(b"foo=bar&")

-        self.assert_fields((b'foo', b'bar'))
+        self.assert_fields((b"foo", b"bar"))

     def test_multiple_querystring(self):
-        self.p.write(b'foo=bar&asdf=baz')
+        self.p.write(b"foo=bar&asdf=baz")

-        self.assert_fields(
-            (b'foo', b'bar'),
-            (b'asdf', b'baz')
-        )
+        self.assert_fields((b"foo", b"bar"), (b"asdf", b"baz"))

     def test_streaming_simple(self):
-        self.p.write(b'foo=bar&')
-        self.assert_fields(
-            (b'foo', b'bar'),
-            finalize=False
-        )
+        self.p.write(b"foo=bar&")
+        self.assert_fields((b"foo", b"bar"), finalize=False)

-        self.p.write(b'asdf=baz')
-        self.assert_fields(
-            (b'asdf', b'baz')
-        )
+        self.p.write(b"asdf=baz")
+        self.assert_fields((b"asdf", b"baz"))

     def test_streaming_break(self):
-        self.p.write(b'foo=one')
+        self.p.write(b"foo=one")
         self.assert_fields(finalize=False)

-        self.p.write(b'two')
+        self.p.write(b"two")
         self.assert_fields(finalize=False)

-        self.p.write(b'three')
+        self.p.write(b"three")
         self.assert_fields(finalize=False)

-        self.p.write(b'&asd')
-        self.assert_fields(
-            (b'foo', b'onetwothree'),
-            finalize=False
-        )
+        self.p.write(b"&asd")
+        self.assert_fields((b"foo", b"onetwothree"), finalize=False)

-        self.p.write(b'f=baz')
-        self.assert_fields(
-            (b'asdf', b'baz')
-        )
+        self.p.write(b"f=baz")
+        self.assert_fields((b"asdf", b"baz"))

     def test_semicolon_separator(self):
-        self.p.write(b'foo=bar;asdf=baz')
+        self.p.write(b"foo=bar;asdf=baz")

-        self.assert_fields(
-            (b'foo', b'bar'),
-            (b'asdf', b'baz')
-        )
+        self.assert_fields((b"foo", b"bar"), (b"asdf", b"baz"))

     def test_too_large_field(self):
         self.p.max_size = 15

         # Note: len = 8
         self.p.write(b"foo=bar&")
-        self.assert_fields((b'foo', b'bar'), finalize=False)
+        self.assert_fields((b"foo", b"bar"), finalize=False)

         # Note: len = 8, only 7 bytes processed
-        self.p.write(b'a=123456')
-        self.assert_fields((b'a', b'12345'))
+        self.p.write(b"a=123456")
+        self.assert_fields((b"a", b"12345"))

     def test_invalid_max_size(self):
         with self.assertRaises(ValueError):
             p = QuerystringParser(max_size=-100)

     def test_strict_parsing_pass(self):
-        data = b'foo=bar&another=asdf'
+        data = b"foo=bar&another=asdf"
         for first, last in split_all(data):
             self.reset()
             self.p.strict_parsing = True
@@ -432,10 +410,10 @@ def test_strict_parsing_pass(self):
             self.p.write(first)
             self.p.write(last)

-            self.assert_fields((b'foo', b'bar'), (b'another', b'asdf'))
+            self.assert_fields((b"foo", b"bar"), (b"another", b"asdf"))

     def test_strict_parsing_fail_double_sep(self):
-        data = b'foo=bar&&another=asdf'
+        data = b"foo=bar&&another=asdf"
         for first, last in split_all(data):
             self.reset()
             self.p.strict_parsing = True
@@ -452,7 +430,7 @@ def test_strict_parsing_fail_double_sep(self):
             self.assertEqual(cm.exception.offset, 8 - cnt)

     def test_double_sep(self):
-        data = b'foo=bar&&another=asdf'
+        data = b"foo=bar&&another=asdf"
         for first, last in split_all(data):
             print(f"  {first!r} / {last!r} ")
             self.reset()
@@ -461,23 +439,19 @@ def test_double_sep(self):
             cnt += self.p.write(first)
             cnt += self.p.write(last)

-        self.assert_fields((b'foo', b'bar'), (b'another', b'asdf'))
+        self.assert_fields((b"foo", b"bar"), (b"another", b"asdf"))

     def test_strict_parsing_fail_no_value(self):
         self.p.strict_parsing = True
         with self.assertRaises(QuerystringParseError) as cm:
-            self.p.write(b'foo=bar&blank&another=asdf')
+            self.p.write(b"foo=bar&blank&another=asdf")

         if cm is not None:
             self.assertEqual(cm.exception.offset, 8)

     def test_success_no_value(self):
-        self.p.write(b'foo=bar&blank&another=asdf')
-        self.assert_fields(
-            (b'foo', b'bar'),
-            (b'blank', b''),
-            (b'another', b'asdf')
-        )
+        self.p.write(b"foo=bar&blank&another=asdf")
+        self.assert_fields((b"foo", b"bar"), (b"blank", b""), (b"another", b"asdf"))

     def test_repr(self):
         # Issue #29; verify we don't assert on repr()
@@ -499,16 +473,12 @@ def on_data(data, start, end):

         def on_end():
             self.finished += 1

-        callbacks = {
-            'on_start': on_start,
-            'on_data': on_data,
-            'on_end': on_end
-        }
+        callbacks = {"on_start": on_start, "on_data": on_data, "on_end": on_end}

         self.p = OctetStreamParser(callbacks)

     def assert_data(self, data, finalize=True):
-        self.assertEqual(b''.join(self.d), data)
+        self.assertEqual(b"".join(self.d), data)
         self.d = []

     def assert_started(self, val=True):
@@ -528,9 +498,9 @@ def test_simple(self):
         self.assert_started(False)

         # Write something, it should then be started + have data
-        self.p.write(b'foobar')
+        self.p.write(b"foobar")
         self.assert_started()
-        self.assert_data(b'foobar')
+        self.assert_data(b"foobar")

         # Finalize, and check
         self.assert_finished(False)
@@ -538,26 +508,26 @@ def test_simple(self):
         self.assert_finished()

     def test_multiple_chunks(self):
-        self.p.write(b'foo')
-        self.p.write(b'bar')
-        self.p.write(b'baz')
+        self.p.write(b"foo")
+        self.p.write(b"bar")
+        self.p.write(b"baz")
         self.p.finalize()

-        self.assert_data(b'foobarbaz')
+        self.assert_data(b"foobarbaz")
         self.assert_finished()

     def test_max_size(self):
         self.p.max_size = 5

-        self.p.write(b'0123456789')
+        self.p.write(b"0123456789")
         self.p.finalize()

-        self.assert_data(b'01234')
+        self.assert_data(b"01234")
         self.assert_finished()

     def test_invalid_max_size(self):
         with self.assertRaises(ValueError):
-            q = OctetStreamParser(max_size='foo')
+            q = OctetStreamParser(max_size="foo")


 class TestBase64Decoder(unittest.TestCase):
@@ -576,37 +546,37 @@ def assert_data(self, data, finalize=True):
         self.f.truncate()

     def test_simple(self):
-        self.d.write(b'Zm9vYmFy')
-        self.assert_data(b'foobar')
+        self.d.write(b"Zm9vYmFy")
+        self.assert_data(b"foobar")

     def test_bad(self):
         with self.assertRaises(DecodeError):
-            self.d.write(b'Zm9v!mFy')
+            self.d.write(b"Zm9v!mFy")

     def test_split_properly(self):
-        self.d.write(b'Zm9v')
-        self.d.write(b'YmFy')
-        self.assert_data(b'foobar')
+        self.d.write(b"Zm9v")
+        self.d.write(b"YmFy")
+        self.assert_data(b"foobar")

     def test_bad_split(self):
-        buff = b'Zm9v'
+        buff = b"Zm9v"
         for i in range(1, 4):
             first, second = buff[:i], buff[i:]

             self.setUp()
             self.d.write(first)
             self.d.write(second)
-            self.assert_data(b'foo')
+            self.assert_data(b"foo")

     def test_long_bad_split(self):
-        buff = b'Zm9vYmFy'
+        buff = b"Zm9vYmFy"
         for i in range(5, 8):
             first, second = buff[:i], buff[i:]

             self.setUp()
             self.d.write(first)
             self.d.write(second)
-            self.assert_data(b'foobar')
+            self.assert_data(b"foobar")

     def test_close_and_finalize(self):
         parser = Mock()
@@ -619,7 +589,7 @@ def test_close_and_finalize(self):
         parser.close.assert_called_once_with()

     def test_bad_length(self):
-        self.d.write(b'Zm9vYmF')  # missing ending 'y'
+        self.d.write(b"Zm9vYmF")  # missing ending 'y'

         with self.assertRaises(DecodeError):
             self.d.finalize()

@@ -640,35 +610,35 @@ def assert_data(self, data, finalize=True):
         self.f.truncate()

     def test_simple(self):
-        self.d.write(b'foobar')
-        self.assert_data(b'foobar')
+        self.d.write(b"foobar")
+        self.assert_data(b"foobar")

     def test_with_escape(self):
-        self.d.write(b'foo=3Dbar')
-        self.assert_data(b'foo=bar')
+        self.d.write(b"foo=3Dbar")
+        self.assert_data(b"foo=bar")

     def test_with_newline_escape(self):
-        self.d.write(b'foo=\r\nbar')
-        self.assert_data(b'foobar')
+        self.d.write(b"foo=\r\nbar")
+        self.assert_data(b"foobar")

     def test_with_only_newline_escape(self):
-        self.d.write(b'foo=\nbar')
-        self.assert_data(b'foobar')
+        self.d.write(b"foo=\nbar")
+        self.assert_data(b"foobar")

     def test_with_split_escape(self):
-        self.d.write(b'foo=3')
-        self.d.write(b'Dbar')
-        self.assert_data(b'foo=bar')
+        self.d.write(b"foo=3")
+        self.d.write(b"Dbar")
+        self.assert_data(b"foo=bar")

     def test_with_split_newline_escape_1(self):
-        self.d.write(b'foo=\r')
-        self.d.write(b'\nbar')
-        self.assert_data(b'foobar')
+        self.d.write(b"foo=\r")
+        self.d.write(b"\nbar")
+        self.assert_data(b"foobar")

     def test_with_split_newline_escape_2(self):
-        self.d.write(b'foo=')
-        self.d.write(b'\r\nbar')
-        self.assert_data(b'foobar')
+        self.d.write(b"foo=")
+        self.d.write(b"\r\nbar")
+        self.assert_data(b"foobar")

     def test_close_and_finalize(self):
         parser = Mock()
@@ -684,23 +654,23 @@ def test_not_aligned(self):
         """
         https://github.com/andrew-d/python-multipart/issues/6
         """
-        self.d.write(b'=3AX')
-        self.assert_data(b':X')
+        self.d.write(b"=3AX")
+        self.assert_data(b":X")

         # Additional offset tests
-        self.d.write(b'=3')
-        self.d.write(b'AX')
-        self.assert_data(b':X')
+        self.d.write(b"=3")
+        self.d.write(b"AX")
+        self.assert_data(b":X")

-        self.d.write(b'q=3AX')
-        self.assert_data(b'q:X')
+        self.d.write(b"q=3AX")
+        self.assert_data(b"q:X")


 # Load our list of HTTP test cases.
-http_tests_dir = os.path.join(curr_dir, 'test_data', 'http')
+http_tests_dir = os.path.join(curr_dir, "test_data", "http")

 # Read in all test cases and load them.
-NON_PARAMETRIZED_TESTS = {'single_field_blocks'}
+NON_PARAMETRIZED_TESTS = {"single_field_blocks"}
 http_tests = []
 for f in os.listdir(http_tests_dir):
     # Only load the HTTP test cases.
@@ -708,22 +678,18 @@ def test_not_aligned(self):
     if fname in NON_PARAMETRIZED_TESTS:
         continue

-    if ext == '.http':
+    if ext == ".http":
         # Get the YAML file and load it too.
-        yaml_file = os.path.join(http_tests_dir, fname + '.yaml')
+        yaml_file = os.path.join(http_tests_dir, fname + ".yaml")

         # Load both.
-        with open(os.path.join(http_tests_dir, f), 'rb') as f:
+        with open(os.path.join(http_tests_dir, f), "rb") as f:
             test_data = f.read()

-        with open(yaml_file, 'rb') as f:
+        with open(yaml_file, "rb") as f:
             yaml_data = yaml.safe_load(f)

-        http_tests.append({
-            'name': fname,
-            'test': test_data,
-            'result': yaml_data
-        })
+        http_tests.append({"name": fname, "test": test_data, "result": yaml_data})


 def split_all(val):
@@ -754,8 +720,7 @@ def on_end():
             self.ended = True

         # Get a form-parser instance.
-        self.f = FormParser('multipart/form-data', on_field, on_file, on_end,
-                            boundary=boundary, config=config)
+        self.f = FormParser("multipart/form-data", on_field, on_file, on_end, boundary=boundary, config=config)

     def assert_file_data(self, f, data):
         o = f.file_object
@@ -800,18 +765,18 @@ def assert_field(self, name, value):
         # Remove it for future iterations.
         self.fields.remove(found)

-    @parametrize('param', http_tests)
+    @parametrize("param", http_tests)
     def test_http(self, param):
         # Firstly, create our parser with the given boundary.
-        boundary = param['result']['boundary']
+        boundary = param["result"]["boundary"]
         if isinstance(boundary, str):
-            boundary = boundary.encode('latin-1')
+            boundary = boundary.encode("latin-1")
         self.make(boundary)

         # Now, we feed the parser with data.
         exc = None
         try:
-            processed = self.f.write(param['test'])
+            processed = self.f.write(param["test"])
             self.f.finalize()
         except MultipartParseError as e:
             processed = 0
@@ -823,29 +788,25 @@ def test_http(self, param):
         # print(repr(self.files))

         # Do we expect an error?
-        if 'error' in param['result']['expected']:
+        if "error" in param["result"]["expected"]:
             self.assertIsNotNone(exc)
-            self.assertEqual(param['result']['expected']['error'], exc.offset)
+            self.assertEqual(param["result"]["expected"]["error"], exc.offset)
             return

         # No error!
-        self.assertEqual(processed, len(param['test']))
+        self.assertEqual(processed, len(param["test"]))

         # Assert that the parser gave us the appropriate fields/files.
-        for e in param['result']['expected']:
+        for e in param["result"]["expected"]:
             # Get our type and name.
-            type = e['type']
-            name = e['name'].encode('latin-1')
+            type = e["type"]
+            name = e["name"].encode("latin-1")

-            if type == 'field':
-                self.assert_field(name, e['data'])
+            if type == "field":
+                self.assert_field(name, e["data"])

-            elif type == 'file':
-                self.assert_file(
-                    name,
-                    e['file_name'].encode('latin-1'),
-                    e['data']
-                )
+            elif type == "file":
+                self.assert_file(name, e["file_name"].encode("latin-1"), e["data"])

             else:
                 assert False

@@ -856,14 +817,14 @@ def test_random_splitting(self):
         through every possible split.
         """
         # Load test data.
-        test_file = 'single_field_single_file.http'
-        with open(os.path.join(http_tests_dir, test_file), 'rb') as f:
+        test_file = "single_field_single_file.http"
+        with open(os.path.join(http_tests_dir, test_file), "rb") as f:
             test_data = f.read()

         # We split the file through all cases.
         for first, last in split_all(test_data):
             # Create form parser.
-            self.make('boundary')
+            self.make("boundary")

             # Feed with data in 2 chunks.
             i = 0
@@ -875,27 +836,27 @@ def test_random_splitting(self):
             self.assertEqual(i, len(test_data))

             # Assert that our file and field are here.
-            self.assert_field(b'field', b'test1')
-            self.assert_file(b'file', b'file.txt', b'test2')
+            self.assert_field(b"field", b"test1")
+            self.assert_file(b"file", b"file.txt", b"test2")

     def test_feed_single_bytes(self):
         """
         This test parses a simple multipart body 1 byte at a time.
         """
         # Load test data.
-        test_file = 'single_field_single_file.http'
-        with open(os.path.join(http_tests_dir, test_file), 'rb') as f:
+        test_file = "single_field_single_file.http"
+        with open(os.path.join(http_tests_dir, test_file), "rb") as f:
             test_data = f.read()

         # Create form parser.
-        self.make('boundary')
+        self.make("boundary")

         # Write all bytes.
         # NOTE: Can't simply do `for b in test_data`, since that gives
         # an integer when iterating over a bytes object on Python 3.
         i = 0
         for x in range(len(test_data)):
-            b = test_data[x:x + 1]
+            b = test_data[x : x + 1]
             i += self.f.write(b)

         self.f.finalize()
@@ -904,24 +865,23 @@ def test_feed_single_bytes(self):
         self.assertEqual(i, len(test_data))

         # Assert that our file and field are here.
-        self.assert_field(b'field', b'test1')
-        self.assert_file(b'file', b'file.txt', b'test2')
+        self.assert_field(b"field", b"test1")
+        self.assert_file(b"file", b"file.txt", b"test2")

     def test_feed_blocks(self):
         """
         This test parses a simple multipart body 1 byte at a time.
         """
         # Load test data.
-        test_file = 'single_field_blocks.http'
-        with open(os.path.join(http_tests_dir, test_file), 'rb') as f:
+        test_file = "single_field_blocks.http"
+        with open(os.path.join(http_tests_dir, test_file), "rb") as f:
             test_data = f.read()

         for c in range(1, len(test_data) + 1):
             # Skip first `d` bytes - not interesting
             for d in range(c):
-                # Create form parser.
-                self.make('boundary')
+                self.make("boundary")
                 # Skip
                 i = 0
                 self.f.write(test_data[:d])
@@ -930,7 +890,7 @@ def test_feed_blocks(self):
                 # Write a chunk to achieve condition
                 # `i == data_length - 1`
                 # in boundary search loop (multipatr.py:1302)
-                b = test_data[x:x + c]
+                b = test_data[x : x + c]
                 i += self.f.write(b)

                 self.f.finalize()
@@ -939,8 +899,7 @@ def test_feed_blocks(self):
                 self.assertEqual(i, len(test_data))

                 # Assert that our field is here.
-                self.assert_field(b'field',
-                                  b'0123456789ABCDEFGHIJ0123456789ABCDEFGHIJ')
+                self.assert_field(b"field", b"0123456789ABCDEFGHIJ0123456789ABCDEFGHIJ")

     @slow_test
     def test_request_body_fuzz(self):
@@ -953,8 +912,8 @@ def test_request_body_fuzz(self):
         - Randomly swapping two bytes
         """
         # Load test data.
-        test_file = 'single_field_single_file.http'
-        with open(os.path.join(http_tests_dir, test_file), 'rb') as f:
+        test_file = "single_field_single_file.http"
+        with open(os.path.join(http_tests_dir, test_file), "rb") as f:
             test_data = f.read()

         iterations = 1000
@@ -995,7 +954,7 @@ def test_request_body_fuzz(self):
                 print("  " + msg)

             # Create form parser.
-            self.make('boundary')
+            self.make("boundary")

             # Feed with data, and ignore form parser exceptions.
             i = 0
@@ -1033,7 +992,7 @@ def test_request_body_fuzz_random_data(self):
             print("  Testing with %d random bytes..." % (data_size,))

             # Create form parser.
-            self.make('boundary')
+            self.make("boundary")

             # Feed with data, and ignore form parser exceptions.
             i = 0
@@ -1054,40 +1013,44 @@ def test_request_body_fuzz_random_data(self):
         print("Exceptions: %d" % (exceptions,))

     def test_bad_start_boundary(self):
-        self.make('boundary')
-        data = b'--boundary\rfoobar'
+        self.make("boundary")
+        data = b"--boundary\rfoobar"
         with self.assertRaises(MultipartParseError):
             self.f.write(data)

-        self.make('boundary')
-        data = b'--boundaryfoobar'
+        self.make("boundary")
+        data = b"--boundaryfoobar"
         with self.assertRaises(MultipartParseError):
             i = self.f.write(data)

     def test_octet_stream(self):
         files = []
+
         def on_file(f):
             files.append(f)
+
         on_field = Mock()
         on_end = Mock()

-        f = FormParser('application/octet-stream', on_field, on_file, on_end=on_end, file_name=b'foo.txt')
+        f = FormParser("application/octet-stream", on_field, on_file, on_end=on_end, file_name=b"foo.txt")
         self.assertTrue(isinstance(f.parser, OctetStreamParser))

-        f.write(b'test')
-        f.write(b'1234')
+        f.write(b"test")
+        f.write(b"1234")
         f.finalize()

         # Assert that we only received a single file, with the right data, and that we're done.
         self.assertFalse(on_field.called)
         self.assertEqual(len(files), 1)
-        self.assert_file_data(files[0], b'test1234')
+        self.assert_file_data(files[0], b"test1234")
         self.assertTrue(on_end.called)

     def test_querystring(self):
         fields = []
+
         def on_field(f):
             fields.append(f)
+
         on_file = Mock()
         on_end = Mock()

@@ -1098,8 +1061,8 @@ def simple_test(f):
             on_end.reset_mock()

             # Write test data.
-            f.write(b'foo=bar')
-            f.write(b'&test=asdf')
+            f.write(b"foo=bar")
+            f.write(b"&test=asdf")
             f.finalize()

             # Assert we only received 2 fields...
@@ -1107,26 +1070,26 @@ def simple_test(f):
             self.assertEqual(len(fields), 2)

             # ...assert that we have the correct data...
-            self.assertEqual(fields[0].field_name, b'foo')
-            self.assertEqual(fields[0].value, b'bar')
+            self.assertEqual(fields[0].field_name, b"foo")
+            self.assertEqual(fields[0].value, b"bar")

-            self.assertEqual(fields[1].field_name, b'test')
-            self.assertEqual(fields[1].value, b'asdf')
+            self.assertEqual(fields[1].field_name, b"test")
+            self.assertEqual(fields[1].value, b"asdf")

             # ... and assert that we've finished.
             self.assertTrue(on_end.called)

-        f = FormParser('application/x-www-form-urlencoded', on_field, on_file, on_end=on_end)
+        f = FormParser("application/x-www-form-urlencoded", on_field, on_file, on_end=on_end)
         self.assertTrue(isinstance(f.parser, QuerystringParser))
         simple_test(f)

-        f = FormParser('application/x-url-encoded', on_field, on_file, on_end=on_end)
+        f = FormParser("application/x-url-encoded", on_field, on_file, on_end=on_end)
         self.assertTrue(isinstance(f.parser, QuerystringParser))
         simple_test(f)

     def test_close_methods(self):
         parser = Mock()
-        f = FormParser('application/x-url-encoded', None, None)
+        f = FormParser("application/x-url-encoded", None, None)
         f.parser = parser

         f.finalize()
@@ -1138,69 +1101,76 @@ def test_close_methods(self):
     def test_bad_content_type(self):
         # We should raise a ValueError for a bad Content-Type
         with self.assertRaises(ValueError):
-            f = FormParser('application/bad', None, None)
+            f = FormParser("application/bad", None, None)

     def test_no_boundary_given(self):
         # We should raise a FormParserError when parsing a multipart message
         # without a boundary.
         with self.assertRaises(FormParserError):
-            f = FormParser('multipart/form-data', None, None)
+            f = FormParser("multipart/form-data", None, None)

     def test_bad_content_transfer_encoding(self):
-        data = b'----boundary\r\nContent-Disposition: form-data; name="file"; filename="test.txt"\r\nContent-Type: text/plain\r\nContent-Transfer-Encoding: badstuff\r\n\r\nTest\r\n----boundary--\r\n'
+        data = (
+            b'----boundary\r\nContent-Disposition: form-data; name="file"; filename="test.txt"\r\n'
+            b"Content-Type: text/plain\r\n"
+            b"Content-Transfer-Encoding: badstuff\r\n\r\n"
+            b"Test\r\n----boundary--\r\n"
+        )

         files = []
+
         def on_file(f):
             files.append(f)
+
         on_field = Mock()
         on_end = Mock()

         # Test with erroring.
-        config = {'UPLOAD_ERROR_ON_BAD_CTE': True}
-        f = FormParser('multipart/form-data', on_field, on_file,
-                       on_end=on_end, boundary='--boundary', config=config)
+        config = {"UPLOAD_ERROR_ON_BAD_CTE": True}
+        f = FormParser("multipart/form-data", on_field, on_file, on_end=on_end, boundary="--boundary", config=config)

         with self.assertRaises(FormParserError):
             f.write(data)
             f.finalize()

         # Test without erroring.
-        config = {'UPLOAD_ERROR_ON_BAD_CTE': False}
-        f = FormParser('multipart/form-data', on_field, on_file,
-                       on_end=on_end, boundary='--boundary', config=config)
+        config = {"UPLOAD_ERROR_ON_BAD_CTE": False}
+        f = FormParser("multipart/form-data", on_field, on_file, on_end=on_end, boundary="--boundary", config=config)

         f.write(data)
         f.finalize()
-        self.assert_file_data(files[0], b'Test')
+        self.assert_file_data(files[0], b"Test")

     def test_handles_None_fields(self):
         fields = []
+
         def on_field(f):
             fields.append(f)
+
         on_file = Mock()
         on_end = Mock()

-        f = FormParser('application/x-www-form-urlencoded', on_field, on_file, on_end=on_end)
-        f.write(b'foo=bar&another&baz=asdf')
+        f = FormParser("application/x-www-form-urlencoded", on_field, on_file, on_end=on_end)
+        f.write(b"foo=bar&another&baz=asdf")
         f.finalize()

-        self.assertEqual(fields[0].field_name, b'foo')
-        self.assertEqual(fields[0].value, b'bar')
+        self.assertEqual(fields[0].field_name, b"foo")
+        self.assertEqual(fields[0].value, b"bar")

-        self.assertEqual(fields[1].field_name, b'another')
+        self.assertEqual(fields[1].field_name, b"another")
         self.assertEqual(fields[1].value, None)

-        self.assertEqual(fields[2].field_name, b'baz')
-        self.assertEqual(fields[2].value, b'asdf')
+        self.assertEqual(fields[2].field_name, b"baz")
+        self.assertEqual(fields[2].value, b"asdf")

     def test_max_size_multipart(self):
         # Load test data.
-        test_file = 'single_field_single_file.http'
-        with open(os.path.join(http_tests_dir, test_file), 'rb') as f:
+        test_file = "single_field_single_file.http"
+        with open(os.path.join(http_tests_dir, test_file), "rb") as f:
             test_data = f.read()

         # Create form parser.
-        self.make('boundary')
+        self.make("boundary")

         # Set the maximum length that we can process to be halfway through the
         # given data.
@@ -1214,14 +1184,14 @@ def test_max_size_multipart(self):

     def test_max_size_form_parser(self):
         # Load test data.
-        test_file = 'single_field_single_file.http'
-        with open(os.path.join(http_tests_dir, test_file), 'rb') as f:
+        test_file = "single_field_single_file.http"
+        with open(os.path.join(http_tests_dir, test_file), "rb") as f:
             test_data = f.read()

         # Create form parser setting the maximum length that we can process to
         # be halfway through the given data.
         size = len(test_data) / 2
-        self.make('boundary', config={'MAX_BODY_SIZE': size})
+        self.make("boundary", config={"MAX_BODY_SIZE": size})

         i = self.f.write(test_data)
         self.f.finalize()
@@ -1231,29 +1201,35 @@ def test_max_size_form_parser(self):

     def test_octet_stream_max_size(self):
         files = []
+
         def on_file(f):
             files.append(f)
+
         on_field = Mock()
         on_end = Mock()

-        f = FormParser('application/octet-stream', on_field, on_file,
-                       on_end=on_end, file_name=b'foo.txt',
-                       config={'MAX_BODY_SIZE': 10})
+        f = FormParser(
+            "application/octet-stream",
+            on_field,
+            on_file,
+            on_end=on_end,
+            file_name=b"foo.txt",
+            config={"MAX_BODY_SIZE": 10},
+        )

-        f.write(b'0123456789012345689')
+        f.write(b"0123456789012345689")
         f.finalize()

-        self.assert_file_data(files[0], b'0123456789')
+        self.assert_file_data(files[0], b"0123456789")

     def test_invalid_max_size_multipart(self):
         with self.assertRaises(ValueError):
-            q = MultipartParser(b'bound', max_size='foo')
+            q = MultipartParser(b"bound", max_size="foo")


 class TestHelperFunctions(unittest.TestCase):
     def test_create_form_parser(self):
-        r = create_form_parser({'Content-Type': 'application/octet-stream'},
-                               None, None)
+        r = create_form_parser({"Content-Type": "application/octet-stream"}, None, None)
         self.assertTrue(isinstance(r, FormParser))

     def test_create_form_parser_error(self):
@@ -1265,13 +1241,7 @@ def test_parse_form(self):
         on_field = Mock()
         on_file = Mock()

-        parse_form(
-            {'Content-Type': 'application/octet-stream',
-             },
-            BytesIO(b'123456789012345'),
-            on_field,
-            on_file
-        )
+        parse_form({"Content-Type": "application/octet-stream"}, BytesIO(b"123456789012345"), on_field, on_file)

         assert on_file.call_count == 1

@@ -1281,23 +1251,21 @@ def test_parse_form(self):

     def test_parse_form_content_length(self):
         files = []
+
         def on_file(file):
             files.append(file)

         parse_form(
-            {'Content-Type': 'application/octet-stream',
-             'Content-Length': '10'
-            },
-            BytesIO(b'123456789012345'),
+            {"Content-Type": "application/octet-stream", "Content-Length": "10"},
+            BytesIO(b"123456789012345"),
             None,
-            on_file
+            on_file,
         )

         self.assertEqual(len(files), 1)
         self.assertEqual(files[0].size, 10)

-
 def suite():
     suite = unittest.TestSuite()
     suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestFile))