Skip to content

Commit

Permalink
Switch event classes to dataclasses
Browse files Browse the repository at this point in the history
The _EventBundle used previously dynamically created the Event
classes, which made typing hard. As the Event classes are now stable,
the dynamic creation isn't required, and as h11 supports Python 3.6+,
dataclasses can be used instead.

This now makes the Events frozen (and almost immutable), which better
matches the intended API. In turn it requires the `object.__setattr__`
usage and the alteration to the
`_clean_up_response_headers_for_sending` method.

This change also improves performance a little, as measured by the
benchmark:

Before: 6.7k requests/sec
After: 6.9k requests/sec

Notes:

The test response-header changes are required as the previous version
would mutate the response object.

The explicit `__init__` for the Data event is required as `__slots__` and
field defaults can't be combined in a dataclass.
  • Loading branch information
pgjones committed May 16, 2021
1 parent 0f719f0 commit 5d958b5
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 143 deletions.
12 changes: 8 additions & 4 deletions h11/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def send_with_data_passthrough(self, event):
raise LocalProtocolError("Can't send data when our state is ERROR")
try:
if type(event) is Response:
self._clean_up_response_headers_for_sending(event)
event = self._clean_up_response_headers_for_sending(event)
# We want to call _process_event before calling the writer,
# because if someone tries to do something invalid then this will
# give a sensible error message, while our writers all just assume
Expand Down Expand Up @@ -528,8 +528,7 @@ def send_failed(self):
#
# This function's *only* responsibility is making sure headers are set up
# right -- everything downstream just looks at the headers. There are no
# side channels. It mutates the response event in-place (but not the
# response.headers list object).
# side channels.
def _clean_up_response_headers_for_sending(self, response):
assert type(response) is Response

Expand Down Expand Up @@ -582,4 +581,9 @@ def _clean_up_response_headers_for_sending(self, response):
connection.add(b"close")
headers = set_comma_header(headers, b"connection", sorted(connection))

response.headers = headers
return Response(
headers=headers,
status_code=response.status_code,
http_version=response.http_version,
reason=response.reason,
)
223 changes: 144 additions & 79 deletions h11/_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
# Don't subclass these. Stuff will break.

import re
from abc import ABC
from dataclasses import dataclass, field
from typing import Any, cast, Dict, List, Tuple, Union

from . import _headers
from ._abnf import request_target
from ._headers import Headers, normalize_and_validate
from ._util import bytesify, LocalProtocolError, validate

# Everything in __all__ gets re-exported as part of the h11 public API.
__all__ = [
"Event",
"Request",
"InformationalResponse",
"Response",
Expand All @@ -24,72 +28,16 @@
request_target_re = re.compile(request_target.encode("ascii"))


class _EventBundle:
_fields = []
_defaults = {}

def __init__(self, **kwargs):
_parsed = kwargs.pop("_parsed", False)
allowed = set(self._fields)
for kwarg in kwargs:
if kwarg not in allowed:
raise TypeError(
"unrecognized kwarg {} for {}".format(
kwarg, self.__class__.__name__
)
)
required = allowed.difference(self._defaults)
for field in required:
if field not in kwargs:
raise TypeError(
"missing required kwarg {} for {}".format(
field, self.__class__.__name__
)
)
self.__dict__.update(self._defaults)
self.__dict__.update(kwargs)

# Special handling for some fields

if "headers" in self.__dict__:
self.headers = _headers.normalize_and_validate(
self.headers, _parsed=_parsed
)

if not _parsed:
for field in ["method", "target", "http_version", "reason"]:
if field in self.__dict__:
self.__dict__[field] = bytesify(self.__dict__[field])

if "status_code" in self.__dict__:
if not isinstance(self.status_code, int):
raise LocalProtocolError("status code must be integer")
# Because IntEnum objects are instances of int, but aren't
# duck-compatible (sigh), see gh-72.
self.status_code = int(self.status_code)

self._validate()

def _validate(self):
pass

def __repr__(self):
name = self.__class__.__name__
kwarg_strs = [
"{}={}".format(field, self.__dict__[field]) for field in self._fields
]
kwarg_str = ", ".join(kwarg_strs)
return "{}({})".format(name, kwarg_str)

# Useful for tests
def __eq__(self, other):
return self.__class__ == other.__class__ and self.__dict__ == other.__dict__
class Event(ABC):
"""
Base class for h11 events.
"""

# This is an unhashable type.
__hash__ = None
__slots__ = ()


class Request(_EventBundle):
@dataclass(init=False, frozen=True)
class Request(Event):
"""The beginning of an HTTP request.
Fields:
Expand Down Expand Up @@ -123,10 +71,38 @@ class Request(_EventBundle):
"""

_fields = ["method", "target", "headers", "http_version"]
_defaults = {"http_version": b"1.1"}
__slots__ = ("method", "headers", "target", "http_version")

method: bytes
headers: Headers
target: bytes
http_version: bytes

def __init__(
self,
*,
method: Union[bytes, str],
headers: Union[Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]]],
target: Union[bytes, str],
http_version: Union[bytes, str] = b"1.1",
_parsed: bool = False,
) -> None:
super().__init__()
if isinstance(headers, Headers):
object.__setattr__(self, "headers", headers)
else:
object.__setattr__(
self, "headers", normalize_and_validate(headers, _parsed=_parsed)
)
if not _parsed:
object.__setattr__(self, "method", bytesify(method))
object.__setattr__(self, "target", bytesify(target))
object.__setattr__(self, "http_version", bytesify(http_version))
else:
object.__setattr__(self, "method", method)
object.__setattr__(self, "target", target)
object.__setattr__(self, "http_version", http_version)

def _validate(self):
# "A server MUST respond with a 400 (Bad Request) status code to any
# HTTP/1.1 request message that lacks a Host header field and to any
# request message that contains more than one Host header field or a
Expand All @@ -143,12 +119,58 @@ def _validate(self):

validate(request_target_re, self.target, "Illegal target characters")

# This is an unhashable type.
__hash__ = None # type: ignore


@dataclass(init=False, frozen=True)
class _ResponseBase(Event):
__slots__ = ("headers", "http_version", "reason", "status_code")

headers: Headers
http_version: bytes
reason: bytes
status_code: int

def __init__(
self,
*,
headers: Union[Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]]],
status_code: int,
http_version: Union[bytes, str] = b"1.1",
reason: Union[bytes, str] = b"",
_parsed: bool = False,
) -> None:
super().__init__()
if isinstance(headers, Headers):
object.__setattr__(self, "headers", headers)
else:
object.__setattr__(
self, "headers", normalize_and_validate(headers, _parsed=_parsed)
)
if not _parsed:
object.__setattr__(self, "reason", bytesify(reason))
object.__setattr__(self, "http_version", bytesify(http_version))
if not isinstance(status_code, int):
raise LocalProtocolError("status code must be integer")
# Because IntEnum objects are instances of int, but aren't
# duck-compatible (sigh), see gh-72.
object.__setattr__(self, "status_code", int(status_code))
else:
object.__setattr__(self, "reason", reason)
object.__setattr__(self, "http_version", http_version)
object.__setattr__(self, "status_code", status_code)

self.__post_init__()

def __post_init__(self) -> None:
pass

class _ResponseBase(_EventBundle):
_fields = ["status_code", "headers", "http_version", "reason"]
_defaults = {"http_version": b"1.1", "reason": b""}
# This is an unhashable type.
__hash__ = None # type: ignore


@dataclass(init=False, frozen=True)
class InformationalResponse(_ResponseBase):
"""An HTTP informational response.
Expand Down Expand Up @@ -179,14 +201,18 @@ class InformationalResponse(_ResponseBase):
"""

def _validate(self):
def __post_init__(self) -> None:
if not (100 <= self.status_code < 200):
raise LocalProtocolError(
"InformationalResponse status_code should be in range "
"[100, 200), not {}".format(self.status_code)
)

# This is an unhashable type.
__hash__ = None # type: ignore


@dataclass(init=False, frozen=True)
class Response(_ResponseBase):
"""The beginning of an HTTP response.
Expand Down Expand Up @@ -216,16 +242,20 @@ class Response(_ResponseBase):
"""

def _validate(self):
def __post_init__(self) -> None:
if not (200 <= self.status_code < 600):
raise LocalProtocolError(
"Response status_code should be in range [200, 600), not {}".format(
self.status_code
)
)

# This is an unhashable type.
__hash__ = None # type: ignore


class Data(_EventBundle):
@dataclass(init=False, frozen=True)
class Data(Event):
"""Part of an HTTP message body.
Fields:
Expand Down Expand Up @@ -258,16 +288,30 @@ class Data(_EventBundle):
"""

_fields = ["data", "chunk_start", "chunk_end"]
_defaults = {"chunk_start": False, "chunk_end": False}
__slots__ = ("data", "chunk_start", "chunk_end")

data: bytes
chunk_start: bool
chunk_end: bool

def __init__(
self, data: bytes, chunk_start: bool = False, chunk_end: bool = False
) -> None:
object.__setattr__(self, "data", data)
object.__setattr__(self, "chunk_start", chunk_start)
object.__setattr__(self, "chunk_end", chunk_end)

# This is an unhashable type.
__hash__ = None # type: ignore


# XX FIXME: "A recipient MUST ignore (or consider as an error) any fields that
# are forbidden to be sent in a trailer, since processing them as if they were
# present in the header section might bypass external security filters."
# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#chunked.trailer.part
# Unfortunately, the list of forbidden fields is long and vague :-/
class EndOfMessage(_EventBundle):
@dataclass(init=False, frozen=True)
class EndOfMessage(Event):
"""The end of an HTTP message.
Fields:
Expand All @@ -284,11 +328,32 @@ class EndOfMessage(_EventBundle):
"""

_fields = ["headers"]
_defaults = {"headers": []}
__slots__ = ("headers",)

headers: Headers

def __init__(
self,
*,
headers: Union[
Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]], None
] = None,
_parsed: bool = False,
) -> None:
super().__init__()
if headers is None:
headers = Headers([])
elif not isinstance(headers, Headers):
headers = normalize_and_validate(headers, _parsed=_parsed)

object.__setattr__(self, "headers", headers)

# This is an unhashable type.
__hash__ = None # type: ignore


class ConnectionClosed(_EventBundle):
@dataclass(frozen=True)
class ConnectionClosed(Event):
"""This event indicates that the sender has closed their outgoing
connection.
Expand Down
10 changes: 6 additions & 4 deletions h11/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ def normalize_data_events(in_events):
out_events = []
for event in in_events:
if type(event) is Data:
event.data = bytes(event.data)
event.chunk_start = False
event.chunk_end = False
event = Data(data=bytes(event.data), chunk_start=False, chunk_end=False)
if out_events and type(out_events[-1]) is type(event) is Data:
out_events[-1].data += event.data
out_events[-1] = Data(
data=out_events[-1].data + event.data,
chunk_start=out_events[-1].chunk_start,
chunk_end=out_events[-1].chunk_end,
)
else:
out_events.append(event)
return out_events
Expand Down
Loading

0 comments on commit 5d958b5

Please sign in to comment.