Skip to content

Commit

Permalink
Fix race when using HTTP proxy.
Browse files Browse the repository at this point in the history
After `start_tls` we need to manually call `connection_made` on the
protocol to tell it about the new underlying transport. This gives a
brief window where the protocol can receive data without having a
transport to write to, causing issues with the APNS connections where it
assumes it can write once it starts reading data.

We fix this by wrapping the protocol in a buffer that simply buffers
incoming data until `connection_made` is called.
  • Loading branch information
erikjohnston committed Mar 19, 2021
1 parent 1d9d415 commit 7b6ce0c
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 2 deletions.
50 changes: 48 additions & 2 deletions sygnal/helper/proxy/proxy_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from ssl import Purpose, SSLContext, create_default_context
from typing import Callable, Optional, Tuple, Union

import attr

from sygnal.exceptions import ProxyConnectError
from sygnal.helper.proxy import decompose_http_proxy_url

Expand Down Expand Up @@ -146,19 +148,25 @@ async def switch_over_when_ready(self) -> Tuple[BaseTransport, Protocol]:
# unreachable
raise RuntimeError("Left over bytes should not occur with TLS")

# There is a race where the `new_protocol` may get given data before
# we manage to call `connection_made` on it, which can lead to
# exceptions if the protocol then tries to write to the transport
# that is has been given yet.
buffered_protocol = _BufferedWrapperProtocol(new_protocol)

# be careful not to use the `transport` ever again after passing it
# to start_tls — we overwrite our variable with the TLS-wrapped
# transport to avoid that!
transport = await self._event_loop.start_tls(
self._transport,
new_protocol,
buffered_protocol,
self._sslcontext,
server_hostname=self._target_hostport[0],
)

# start_tls does NOT call connection_made on new_protocol, so we
# must do it ourselves
new_protocol.connection_made(transport)
buffered_protocol.connection_made(transport)
else:
# no wrapping required for non-TLS
transport = self._transport
Expand All @@ -171,6 +179,8 @@ async def switch_over_when_ready(self) -> Tuple[BaseTransport, Protocol]:
# pass over dangling bytes if applicable
new_protocol.data_received(left_over_bytes)

logger.debug("Finished switching protocol")

return transport, new_protocol

def data_received(self, data: bytes) -> None:
Expand Down Expand Up @@ -332,3 +342,39 @@ def __getattr__(self, item):
We use this to delegate other method calls to the real EventLoop.
"""
return getattr(self._wrapped_loop, item)


@attr.s(slots=True, auto_attribs=True)
class _BufferedWrapperProtocol(Protocol):
"""Wraps a protocol to buffer any incoming data received before
`connection_made` is called.
"""

_protocol: Protocol
_connected: bool = False
_buffer: bytearray = attr.Factory(bytearray)

def connection_made(self, transport: BaseTransport):
self._connected = True
self._protocol.connection_made(transport)
if self._buffer:
self._protocol.data_received(self._buffer)
self._buffer = bytearray()

def connection_lost(self, exc: Optional[Exception]):
self._protocol.connection_lost(exc)

def pause_writing(self):
self._protocol.pause_writing()

def resume_writing(self):
self._protocol.resume_writing()

def data_received(self, data: bytes):
if self._connected:
self._protocol.data_received(data)
else:
self._buffer.extend(data)

def eof_received(self):
return self._protocol.eof_received()
38 changes: 38 additions & 0 deletions tests/asyncio_test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ def call_later(
):
self.call_at(self._time + delay, callback, *args, context=context)

# We're mean to return a canceller, but can cheat and return a no-op one
# instead.
class _Canceller:
def cancel(self):
pass

return _Canceller()

def call_at(
self,
when: float,
Expand Down Expand Up @@ -114,6 +122,10 @@ def __init__(self):
# Whether this transport was closed
self.closed = False

# We need to explicitly mark that this connection allows start tls,
# otherwise `loop.start_tls` will raise an exception.
self._start_tls_compatible = True

def reset_mock(self) -> None:
self.buffer = b""
self.eofed = False
Expand Down Expand Up @@ -189,3 +201,29 @@ def write(self, data: bytes) -> None:
self.transport.write(data)
else:
self._to_transmit += data


class EchoProtocol(Protocol):
"""A protocol that immediately echoes all data it receives"""

def __init__(self):
self._to_transmit = b""
self.received_bytes = b""
self.transport = None

def data_received(self, data: bytes) -> None:
self.received_bytes += data
assert self.transport
self.transport.write(data)

def connection_made(self, transport: transports.BaseTransport) -> None:
assert isinstance(transport, Transport)
self.transport = transport
if self._to_transmit:
transport.write(self._to_transmit)

def write(self, data: bytes) -> None:
if self.transport:
self.transport.write(data)
else:
self._to_transmit += data
153 changes: 153 additions & 0 deletions tests/test_httpproxy_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import ssl
from asyncio import AbstractEventLoop, BaseTransport, Protocol, Task
from typing import Optional, Tuple, cast

Expand All @@ -21,10 +22,16 @@

from tests import testutils
from tests.asyncio_test_helpers import (
EchoProtocol,
MockProtocol,
MockTransport,
TimelessEventLoopWrapper,
)
from tests.twisted_test_helpers import (
create_test_cert_file,
get_test_ca_cert_file,
get_test_key_file,
)


class AsyncioHttpProxyTest(testutils.TestCase):
Expand Down Expand Up @@ -191,3 +198,149 @@ def test_connect_failure(self):
# check our protocol did not receive anything, because it was an HTTP-
# level error, not actually a connection to our target.
self.assertEqual(fake_protocol.received_bytes, b"")


class AsyncioHttpProxyTLSTest(testutils.TestCase):
"""Test that using a HTTPS proxy works.
This is a bit convoluted to try and test that we don't hit a race where the
new client protocol can receive data before `connection_made` is called,
which can cause problems if it tries to write to the connection that it
hasn't been given yet.
"""

def config_setup(self, config):
super().config_setup(config)
config["apps"]["com.example.spqr"] = {
"type": "tests.test_pushgateway_api_v1.TestPushkin"
}
self.base_loop = asyncio.new_event_loop()
augmented_loop = TimelessEventLoopWrapper(self.base_loop) # type: ignore
asyncio.set_event_loop(cast(AbstractEventLoop, augmented_loop))

self.loop = augmented_loop

self.proxy_context = ssl.create_default_context()
self.proxy_context.load_verify_locations(get_test_ca_cert_file())
self.proxy_context.set_ciphers("DEFAULT")

def make_fake_proxy(
self,
host: str,
port: int,
proxy_credentials: Optional[Tuple[str, str]],
) -> Tuple[MockProtocol, MockTransport, "Task[Tuple[BaseTransport, Protocol]]"]:
# Task[Tuple[MockTransport, MockProtocol]]

# make a fake proxy
fake_proxy = MockTransport()

# We connect with an echo protocol to test that we can always write when
# we receive data.
fake_protocol = EchoProtocol()

# create a HTTP CONNECT proxy client protocol
http_connect_protocol = HttpConnectProtocol(
target_hostport=(host, port),
proxy_credentials=proxy_credentials,
protocol_factory=lambda: fake_protocol,
sslcontext=self.proxy_context,
loop=None,
)
switch_over_task = self.loop.create_task(
http_connect_protocol.switch_over_when_ready()
)
# check the task is not somehow already marked as done before we even
# receive anything.
self.assertFalse(switch_over_task.done())
# connect the proxy client to the proxy
fake_proxy.set_protocol(http_connect_protocol)
http_connect_protocol.connection_made(fake_proxy)
return fake_protocol, fake_proxy, switch_over_task

def test_connect_no_credentials(self):
"""
Tests the proxy connection procedure when there is no basic auth.
"""
host = "example.org"
port = 443
proxy_credentials = None
fake_protocol, fake_proxy, switch_over_task = self.make_fake_proxy(
host, port, proxy_credentials
)

# Check that the proxy got the proper CONNECT request.
self.assertEqual(fake_proxy.buffer, b"CONNECT example.org:443 HTTP/1.0\r\n\r\n")
# Reset the proxy mock
fake_proxy.reset_mock()

# pretend we got a happy response
fake_proxy.pretend_to_receive(b"HTTP/1.0 200 Connection Established\r\n\r\n")

# Since we're talking TLS we need to create a server TLS connection that
# we can use to talk to each other.
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(
create_test_cert_file([b"DNS:example.org"]), keyfile=get_test_key_file()
)
context.set_ciphers("DEFAULT")

# Note that we have to use a different event loop wrapper here as we
# want that server side setup to finish before the client side setup, so
# that we can trigger any races.
server_loop = TimelessEventLoopWrapper(self.base_loop) # type: ignore
server_transport = MockTransport()
proxy_ft = server_loop.create_task(
server_loop.start_tls(
server_transport,
MockProtocol(),
context,
server_hostname=host,
server_side=True,
)
)

# Advance event loop because we have to let coroutines be executed
self.loop.advance(1.0)
server_loop.advance(1.0)

# We manually copy the bytes between the fake_proxy transport and our
# created TLS transport. We do this for each step in the TLS handshake.

# Client -> Server
server_transport.pretend_to_receive(fake_proxy.buffer)
fake_proxy.buffer = b""

# Server -> Client
fake_proxy.pretend_to_receive(server_transport.buffer)
server_transport.buffer = b""

# Client -> Server
server_transport.pretend_to_receive(fake_proxy.buffer)
fake_proxy.buffer = b""

# We *only* advance the server side loop so that we can send data before
# the client has called `connection_made` on the new protocol.
server_loop.advance(0.1)

# Server -> Client application data.
server_plain_transport = proxy_ft.result()
server_plain_transport.write(b"begin beep boop\r\n\r\n~~ :) ~~")
fake_proxy.pretend_to_receive(server_transport.buffer)
server_transport.buffer = b""

self.loop.advance(1.0)

# *now* we should have switched over from the HTTP CONNECT protocol
# to the user protocol (in our case, a MockProtocol).
self.assertTrue(switch_over_task.done())

transport, protocol = switch_over_task.result()

# check it was our protocol that was returned
self.assertIs(protocol, fake_protocol)

# check our protocol received exactly the bytes meant for it
self.assertEqual(
fake_protocol.received_bytes, b"begin beep boop\r\n\r\n~~ :) ~~"
)

0 comments on commit 7b6ce0c

Please sign in to comment.