diff --git a/.gitignore b/.gitignore index 546cf0a..f7f7b71 100644 --- a/.gitignore +++ b/.gitignore @@ -28,7 +28,7 @@ lib64 pip-log.txt # Unit test / coverage reports -.coverage.* +.coverage* .tox nosetests.xml diff --git a/multipart/multipart.py b/multipart/multipart.py index 651bfc1..ac2648e 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -9,11 +9,38 @@ from enum import IntEnum from io import BytesIO from numbers import Number -from typing import Dict, Tuple, Union +from typing import TYPE_CHECKING from .decoders import Base64Decoder, QuotedPrintableDecoder from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError +if TYPE_CHECKING: # pragma: no cover + from typing import Callable, TypedDict + + class QuerystringCallbacks(TypedDict, total=False): + on_field_start: Callable[[], None] + on_field_name: Callable[[bytes, int, int], None] + on_field_data: Callable[[bytes, int, int], None] + on_field_end: Callable[[], None] + on_end: Callable[[], None] + + class OctetStreamCallbacks(TypedDict, total=False): + on_start: Callable[[], None] + on_data: Callable[[bytes, int, int], None] + on_end: Callable[[], None] + + class MultipartCallbacks(TypedDict, total=False): + on_part_begin: Callable[[], None] + on_part_data: Callable[[bytes, int, int], None] + on_part_end: Callable[[], None] + on_headers_begin: Callable[[], None] + on_header_field: Callable[[bytes, int, int], None] + on_header_value: Callable[[bytes, int, int], None] + on_header_end: Callable[[], None] + on_headers_finished: Callable[[], None] + on_end: Callable[[], None] + + # Unique missing object. _missing = object() @@ -86,7 +113,7 @@ def join_bytes(b): return bytes(list(b)) -def parse_options_header(value: Union[str, bytes]) -> Tuple[bytes, Dict[bytes, bytes]]: +def parse_options_header(value: str | bytes) -> tuple[bytes, dict[bytes, bytes]]: """ Parses a Content-Type header into a value in the following format: (content_type, {parameters}) @@ -148,15 +175,15 @@ class Field: :param name: the name of the form field """ - def __init__(self, name): + def __init__(self, name: str): self._name = name - self._value = [] + self._value: list[bytes] = [] # We cache the joined version of _value for speed. self._cache = _missing @classmethod - def from_value(klass, name, value): + def from_value(cls, name: str, value: bytes | None) -> Field: """Create an instance of a :class:`Field`, and set the corresponding value - either None or an actual value. This method will also finalize the Field itself. @@ -166,7 +193,7 @@ def from_value(klass, name, value): None """ - f = klass(name) + f = cls(name) if value is None: f.set_none() else: @@ -174,14 +201,14 @@ def from_value(klass, name, value): f.finalize() return f - def write(self, data): + def write(self, data: bytes) -> int: """Write some data into the form field. :param data: a bytestring """ return self.on_data(data) - def on_data(self, data): + def on_data(self, data: bytes) -> int: """This method is a callback that will be called whenever data is written to the Field. @@ -191,16 +218,16 @@ def on_data(self, data): self._cache = _missing return len(data) - def on_end(self): + def on_end(self) -> None: """This method is called whenever the Field is finalized.""" if self._cache is _missing: self._cache = b"".join(self._value) - def finalize(self): + def finalize(self) -> None: """Finalize the form field.""" self.on_end() - def close(self): + def close(self) -> None: """Close the Field object. This will free any underlying cache.""" # Free our value array. if self._cache is _missing: @@ -208,7 +235,7 @@ def close(self): del self._value - def set_none(self): + def set_none(self) -> None: """Some fields in a querystring can possibly have a value of None - for example, the string "foo&bar=&baz=asdf" will have a field with the name "foo" and value None, one with name "bar" and value "", and one @@ -218,7 +245,7 @@ def set_none(self): self._cache = None @property - def field_name(self): + def field_name(self) -> str: """This property returns the name of the field.""" return self._name @@ -230,13 +257,13 @@ def value(self): return self._cache - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Field): return self.field_name == other.field_name and self.value == other.value else: return NotImplemented - def __repr__(self): + def __repr__(self) -> str: if len(self.value) > 97: # We get the repr, and then insert three dots before the final # quote. @@ -553,7 +580,7 @@ class BaseParser: def __init__(self): self.logger = logging.getLogger(__name__) - def callback(self, name, data=None, start=None, end=None): + def callback(self, name: str, data=None, start=None, end=None): """This function calls a provided callback with some data. If the callback is not set, will do nothing. @@ -584,7 +611,7 @@ def callback(self, name, data=None, start=None, end=None): self.logger.debug("Calling %s with no data", name) func() - def set_callback(self, name, new_func): + def set_callback(self, name: str, new_func): """Update the function for a callback. Removes from the callbacks dict if new_func is None. @@ -637,7 +664,7 @@ class OctetStreamParser(BaseParser): i.e. unbounded. """ - def __init__(self, callbacks={}, max_size=float("inf")): + def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size=float("inf")): super().__init__() self.callbacks = callbacks self._started = False @@ -647,7 +674,7 @@ def __init__(self, callbacks={}, max_size=float("inf")): self.max_size = max_size self._current_size = 0 - def write(self, data): + def write(self, data: bytes): """Write some data to the parser, which will perform size verification, and then pass the data to the underlying callback. @@ -732,7 +759,9 @@ class QuerystringParser(BaseParser): i.e. unbounded. """ - def __init__(self, callbacks={}, strict_parsing=False, max_size=float("inf")): + state: QuerystringState + + def __init__(self, callbacks: QuerystringCallbacks = {}, strict_parsing=False, max_size=float("inf")): super().__init__() self.state = QuerystringState.BEFORE_FIELD self._found_sep = False @@ -748,7 +777,7 @@ def __init__(self, callbacks={}, strict_parsing=False, max_size=float("inf")): # Should parsing be strict? self.strict_parsing = strict_parsing - def write(self, data): + def write(self, data: bytes): """Write some data to the parser, which will perform size verification, parse into either a field name or value, and then pass the corresponding data to the underlying callback. If an error is @@ -780,7 +809,7 @@ def write(self, data): return l - def _internal_write(self, data, length): + def _internal_write(self, data: bytes, length: int): state = self.state strict_parsing = self.strict_parsing found_sep = self._found_sep @@ -989,7 +1018,7 @@ class MultipartParser(BaseParser): i.e. unbounded. """ - def __init__(self, boundary, callbacks={}, max_size=float("inf")): + def __init__(self, boundary, callbacks: MultipartCallbacks = {}, max_size=float("inf")): # Initialize parser state. super().__init__() self.state = MultipartState.START diff --git a/tests/test_multipart.py b/tests/test_multipart.py index b9cba86..16db5b3 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -333,9 +333,9 @@ def on_field_end(): del name_buffer[:] del data_buffer[:] - callbacks = {"on_field_name": on_field_name, "on_field_data": on_field_data, "on_field_end": on_field_end} - - self.p = QuerystringParser(callbacks) + self.p = QuerystringParser( + callbacks={"on_field_name": on_field_name, "on_field_data": on_field_data, "on_field_end": on_field_end} + ) def test_simple_querystring(self): self.p.write(b"foo=bar") @@ -464,18 +464,16 @@ def setUp(self): self.started = 0 self.finished = 0 - def on_start(): + def on_start() -> None: self.started += 1 - def on_data(data, start, end): + def on_data(data: bytes, start: int, end: int) -> None: self.d.append(data[start:end]) - def on_end(): + def on_end() -> None: self.finished += 1 - callbacks = {"on_start": on_start, "on_data": on_data, "on_end": on_end} - - self.p = OctetStreamParser(callbacks) + self.p = OctetStreamParser(callbacks={"on_start": on_start, "on_data": on_data, "on_end": on_end}) def assert_data(self, data, finalize=True): self.assertEqual(b"".join(self.d), data)