diff --git a/multipart/multipart.py b/multipart/multipart.py index 3275075..eac3ff8 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -146,10 +146,6 @@ def ord_char(c: int) -> int: return c -def join_bytes(b: bytes) -> bytes: - return bytes(list(b)) - - def parse_options_header(value: str | bytes) -> tuple[bytes, dict[bytes, bytes]]: """Parses a Content-Type header into a value in the following format: (content_type, {parameters}).""" # Uses email.message.Message to parse the header as described in PEP 594. @@ -976,29 +972,11 @@ def __init__( # Setup marks. These are used to track the state of data received. self.marks: dict[str, int] = {} - # TODO: Actually use this rather than the dumb version we currently use - # # Precompute the skip table for the Boyer-Moore-Horspool algorithm. - # skip = [len(boundary) for x in range(256)] - # for i in range(len(boundary) - 1): - # skip[ord_char(boundary[i])] = len(boundary) - i - 1 - # - # # We use a tuple since it's a constant, and marginally faster. - # self.skip = tuple(skip) - # Save our boundary. if isinstance(boundary, str): # pragma: no cover boundary = boundary.encode("latin-1") self.boundary = b"\r\n--" + boundary - # Get a set of characters that belong to our boundary. - self.boundary_chars = frozenset(self.boundary) - - # We also create a lookbehind list. - # Note: the +8 is since we can have, at maximum, "\r\n--" + boundary + - # "--\r\n" at the final boundary, and the length of '\r\n--' and - # '--\r\n' is 8 bytes. - self.lookbehind = [NULL for _ in range(len(boundary) + 8)] - def write(self, data: bytes) -> int: """Write some data to the parser, which will perform size verification, and then parse the data into the appropriate location (e.g. header, @@ -1061,21 +1039,43 @@ def delete_mark(name: str, reset: bool = False) -> None: # end of the buffer, and reset the mark, instead of deleting it. This # is used at the end of the function to call our callbacks with any # remaining data in this chunk. - def data_callback(name: str, remaining: bool = False) -> None: + def data_callback(name: str, end_i: int, remaining: bool = False) -> None: marked_index = self.marks.get(name) if marked_index is None: return - # If we're getting remaining data, we ignore the current i value - # and just call with the remaining data. - if remaining: - self.callback(name, data, marked_index, length) - self.marks[name] = 0 - # Otherwise, we call it from the mark to the current byte we're # processing. + if end_i <= marked_index: + # There is no additional data to send. + pass + elif marked_index >= 0: + # We are emitting data from the local buffer. + self.callback(name, data, marked_index, end_i) + else: + # Some of the data comes from a partial boundary match. + # and requires look-behind. + # We need to use self.flags (and not flags) because we care about + # the state when we entered the loop. + lookbehind_len = -marked_index + if lookbehind_len <= len(boundary): + self.callback(name, boundary, 0, lookbehind_len) + elif self.flags & FLAG_PART_BOUNDARY: + lookback = boundary + b"\r\n" + self.callback(name, lookback, 0, lookbehind_len) + elif self.flags & FLAG_LAST_BOUNDARY: + lookback = boundary + b"--\r\n" + self.callback(name, lookback, 0, lookbehind_len) + else: # pragma: no cover (error case) + self.logger.warning("Look-back buffer error") + + if end_i > 0: + self.callback(name, data, 0, end_i) + # If we're getting remaining data, we have got all the data we + # can be certain is not a boundary, leaving only a partial boundary match. + if remaining: + self.marks[name] = end_i - length else: - self.callback(name, data, marked_index, i) self.marks.pop(name, None) # For each byte... @@ -1183,7 +1183,7 @@ def data_callback(name: str, remaining: bool = False) -> None: raise e # Call our callback with the header field. - data_callback("header_field") + data_callback("header_field", i) # Move to parsing the header value. state = MultipartState.HEADER_VALUE_START @@ -1212,7 +1212,7 @@ def data_callback(name: str, remaining: bool = False) -> None: # If we've got a CR, we're nearly done our headers. Otherwise, # we do nothing and just move past this character. if c == CR: - data_callback("header_value") + data_callback("header_value", i) self.callback("header_end") state = MultipartState.HEADER_VALUE_ALMOST_DONE @@ -1256,9 +1256,6 @@ def data_callback(name: str, remaining: bool = False) -> None: # We're processing our part data right now. During this, we # need to efficiently search for our boundary, since any data # on any number of lines can be a part of the current data. - # We use the Boyer-Moore-Horspool algorithm to efficiently - # search through the remainder of the buffer looking for our - # boundary. # Save the current value of our index. We use this in case we # find part of a boundary, but it doesn't match fully. @@ -1266,24 +1263,32 @@ def data_callback(name: str, remaining: bool = False) -> None: # Set up variables. boundary_length = len(boundary) - boundary_end = boundary_length - 1 data_length = length - boundary_chars = self.boundary_chars # If our index is 0, we're starting a new part, so start our # search. if index == 0: - # Search forward until we either hit the end of our buffer, - # or reach a character that's in our boundary. - i += boundary_end - while i < data_length - 1 and data[i] not in boundary_chars: - i += boundary_length - - # Reset i back the length of our boundary, which is the - # earliest possible location that could be our match (i.e. - # if we've just broken out of our loop since we saw the - # last character in our boundary) - i -= boundary_end + # The most common case is likely to be that the whole + # boundary is present in the buffer. + # Calling `find` is much faster than iterating here. + i0 = data.find(boundary, i, data_length) + if i0 >= 0: + # We matched the whole boundary string. + index = boundary_length - 1 + i = i0 + boundary_length - 1 + else: + # No match found for whole string. + # There may be a partial boundary at the end of the + # data, which the find will not match. + # Since the length should to be searched is limited to + # the boundary length, just perform a naive search. + i = max(i, data_length - boundary_length) + + # Search forward until we either hit the end of our buffer, + # or reach a potential start of the boundary. + while i < data_length - 1 and data[i] != boundary[0]: + i += 1 + c = data[i] # Now, we have a couple of cases here. If our index is before @@ -1291,11 +1296,6 @@ def data_callback(name: str, remaining: bool = False) -> None: if index < boundary_length: # If the character matches... if boundary[index] == c: - # If we found a match for our boundary, we send the - # existing data. - if index == 0: - data_callback("part_data") - # The current character matches, so continue! index += 1 else: @@ -1332,6 +1332,8 @@ def data_callback(name: str, remaining: bool = False) -> None: # Unset the part boundary flag. flags &= ~FLAG_PART_BOUNDARY + # We have identified a boundary, callback for any data before it. + data_callback("part_data", i - index) # Callback indicating that we've reached the end of # a part, and are starting a new one. self.callback("part_end") @@ -1353,6 +1355,8 @@ def data_callback(name: str, remaining: bool = False) -> None: elif flags & FLAG_LAST_BOUNDARY: # We need a second hyphen here. if c == HYPHEN: + # We have identified a boundary, callback for any data before it. + data_callback("part_data", i - index) # Callback to end the current part, and then the # message. self.callback("part_end") @@ -1362,26 +1366,14 @@ def data_callback(name: str, remaining: bool = False) -> None: # No match, so reset index. index = 0 - # If we have an index, we need to keep this byte for later, in - # case we can't match the full boundary. - if index > 0: - self.lookbehind[index - 1] = c - # Otherwise, our index is 0. If the previous index is not, it # means we reset something, and we need to take the data we # thought was part of our boundary and send it along as actual # data. - elif prev_index > 0: - # Callback to write the saved data. - lb_data = join_bytes(self.lookbehind) - self.callback("part_data", lb_data, 0, prev_index) - + if index == 0 and prev_index > 0: # Overwrite our previous index. prev_index = 0 - # Re-set our mark for part data. - set_mark("part_data") - # Re-consider the current character, since this could be # the start of the boundary itself. i -= 1 @@ -1410,9 +1402,9 @@ def data_callback(name: str, remaining: bool = False) -> None: # that we haven't yet reached the end of this 'thing'. So, by setting # the mark to 0, we cause any data callbacks that take place in future # calls to this function to start from the beginning of that buffer. - data_callback("header_field", True) - data_callback("header_value", True) - data_callback("part_data", True) + data_callback("header_field", length, True) + data_callback("header_value", length, True) + data_callback("part_data", length - index, True) # Save values to locals. self.state = state diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 3a814fb..2e22812 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -695,6 +695,14 @@ def test_not_aligned(self): http_tests.append({"name": fname, "test": test_data, "result": yaml_data}) +# Datasets used for single-byte writing test. +single_byte_tests = [ + "almost_match_boundary", + "almost_match_boundary_without_CR", + "almost_match_boundary_without_LF", + "almost_match_boundary_without_final_hyphen", + "single_field_single_file", +] def split_all(val): """ @@ -843,17 +851,19 @@ def test_random_splitting(self): self.assert_field(b"field", b"test1") self.assert_file(b"file", b"file.txt", b"test2") - def test_feed_single_bytes(self): + @parametrize("param", [ t for t in http_tests if t["name"] in single_byte_tests]) + def test_feed_single_bytes(self, param): """ - This test parses a simple multipart body 1 byte at a time. + This test parses multipart bodies 1 byte at a time. """ # Load test data. - test_file = "single_field_single_file.http" + test_file = param["name"] + ".http" + boundary = param["result"]["boundary"] with open(os.path.join(http_tests_dir, test_file), "rb") as f: test_data = f.read() # Create form parser. - self.make("boundary") + self.make(boundary) # Write all bytes. # NOTE: Can't simply do `for b in test_data`, since that gives @@ -868,9 +878,20 @@ def test_feed_single_bytes(self): # Assert we processed everything. self.assertEqual(i, len(test_data)) - # Assert that our file and field are here. - self.assert_field(b"field", b"test1") - self.assert_file(b"file", b"file.txt", b"test2") + # Assert that the parser gave us the appropriate fields/files. + for e in param["result"]["expected"]: + # Get our type and name. + type = e["type"] + name = e["name"].encode("latin-1") + + if type == "field": + self.assert_field(name, e["data"]) + + elif type == "file": + self.assert_file(name, e["file_name"].encode("latin-1"), e["data"]) + + else: + assert False def test_feed_blocks(self): """