Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Mount API #1362

Merged
merged 9 commits into from
Nov 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions docs/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -1040,12 +1040,13 @@ class HelloWorldTransport(httpcore.SyncHTTPTransport):
A mock transport that always returns a JSON "Hello, world!" response.
"""

def request(self, method, url, headers=None, stream=None, timeout=None):
def request(self, method, url, headers=None, stream=None, ext=None):
message = {"text": "Hello, world!"}
content = json.dumps(message).encode("utf-8")
stream = httpcore.PlainByteStream(content)
headers = [(b"content-type", b"application/json")]
return b"HTTP/1.1", 200, b"OK", headers, stream
ext = {"http_version": b"HTTP/1.1"}
return 200, headers, stream, ext
```

Which we can use in the same way:
Expand All @@ -1057,3 +1058,54 @@ Which we can use in the same way:
>>> response.json()
{"text": "Hello, world!"}
```

### Mounting transports

You can also mount transports against given schemes or domains, to control
which transport an outgoing request should be routed via, with [the same style
used for specifying proxy routing](#routing).

```python
import httpcore
import httpx

class HTTPSRedirectTransport(httpcore.SyncHTTPTransport):
"""
A transport that always redirects to HTTPS.
"""

def request(self, method, url, headers=None, stream=None, ext=None):
scheme, host, port, path = url
if port is None:
location = b"https://%s%s" % (host, path)
else:
location = b"https://%s:%d%s" % (host, port, path)
stream = httpcore.PlainByteStream(b"")
headers = [(b"location", location)]
ext = {"http_version": b"HTTP/1.1"}
return 303, headers, stream, ext


# A client where any `http` requests are always redirected to `https`
mounts = {'http://': HTTPSRedirectTransport()}
client = httpx.Client(mounts=mounts)
```

A couple of other sketches of how you might take advantage of mounted transports...

Mocking requests to a given domain:

```python
# All requests to "example.org" should be mocked out.
# Other requests occur as usual.
mounts = {"all://example.org": MockTransport()}
client = httpx.Client(mounts=mounts)
```

Adding support for custom schemes:

```python
# Support URLs like "file:///Users/sylvia_green/websites/new_client/index.html"
mounts = {"file://": FileSystemTransport()}
client = httpx.Client(mounts=mounts)
```
55 changes: 33 additions & 22 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
cookies: CookieTypes = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
event_hooks: typing.Dict[str, typing.List[typing.Callable]] = None,
event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None,
base_url: URLTypes = "",
trust_env: bool = True,
):
Expand Down Expand Up @@ -561,11 +561,12 @@ def __init__(
cert: CertTypes = None,
http2: bool = False,
proxies: ProxiesTypes = None,
mounts: typing.Mapping[str, httpcore.SyncHTTPTransport] = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
limits: Limits = DEFAULT_LIMITS,
pool_limits: Limits = None,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
event_hooks: typing.Dict[str, typing.List[typing.Callable]] = None,
event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None,
base_url: URLTypes = "",
transport: httpcore.SyncHTTPTransport = None,
app: typing.Callable = None,
Expand Down Expand Up @@ -611,7 +612,7 @@ def __init__(
app=app,
trust_env=trust_env,
)
self._proxies: typing.Dict[
self._mounts: typing.Dict[
URLPattern, typing.Optional[httpcore.SyncHTTPTransport]
] = {
URLPattern(key): None
Expand All @@ -626,7 +627,12 @@ def __init__(
)
for key, proxy in proxy_map.items()
}
self._proxies = dict(sorted(self._proxies.items()))
if mounts is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand correctly that the mounts overrides proxies? Like:

client = Client(..., proxies={"http://": SomeTransport()}, ..., mounts={"http://": SomeOtherTransport())
assert isinstance(client._mounts["http://"], SomeOtherTransport)

Do we need to outline it somewhere?

self._mounts.update(
{URLPattern(key): transport for key, transport in mounts.items()}
)

self._mounts = dict(sorted(self._mounts.items()))

def _init_transport(
self,
Expand Down Expand Up @@ -681,7 +687,7 @@ def _transport_for_url(self, url: URL) -> httpcore.SyncHTTPTransport:
Returns the transport instance that should be used for a given URL.
This will either be the standard connection pool, or a proxy.
"""
for pattern, transport in self._proxies.items():
for pattern, transport in self._mounts.items():
Copy link
Member

@cdeler cdeler Oct 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So only the thing, why we introduce self._mounts is a lookup over the dict items.

Then why it's a dictionary? May we have a list with tuples there?

Update: the only benefit we get from self._mounts presented by dict is the keys uniques.

if pattern.matches(url):
return self._transport if transport is None else transport

Expand Down Expand Up @@ -1109,17 +1115,17 @@ def close(self) -> None:
self._state = ClientState.CLOSED

self._transport.close()
for proxy in self._proxies.values():
if proxy is not None:
proxy.close()
for transport in self._mounts.values():
if transport is not None:
transport.close()

def __enter__(self: T) -> T:
self._state = ClientState.OPENED

self._transport.__enter__()
for proxy in self._proxies.values():
if proxy is not None:
proxy.__enter__()
for transport in self._mounts.values():
if transport is not None:
transport.__enter__()
return self

def __exit__(
Expand All @@ -1131,9 +1137,9 @@ def __exit__(
self._state = ClientState.CLOSED

self._transport.__exit__(exc_type, exc_value, traceback)
for proxy in self._proxies.values():
if proxy is not None:
proxy.__exit__(exc_type, exc_value, traceback)
for transport in self._mounts.values():
if transport is not None:
transport.__exit__(exc_type, exc_value, traceback)

def __del__(self) -> None:
self.close()
Expand Down Expand Up @@ -1198,11 +1204,12 @@ def __init__(
cert: CertTypes = None,
http2: bool = False,
proxies: ProxiesTypes = None,
mounts: typing.Mapping[str, httpcore.AsyncHTTPTransport] = None,
timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
limits: Limits = DEFAULT_LIMITS,
pool_limits: Limits = None,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
event_hooks: typing.Dict[str, typing.List[typing.Callable]] = None,
event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None,
base_url: URLTypes = "",
transport: httpcore.AsyncHTTPTransport = None,
app: typing.Callable = None,
Expand Down Expand Up @@ -1249,7 +1256,7 @@ def __init__(
trust_env=trust_env,
)

self._proxies: typing.Dict[
self._mounts: typing.Dict[
URLPattern, typing.Optional[httpcore.AsyncHTTPTransport]
] = {
URLPattern(key): None
Expand All @@ -1264,7 +1271,11 @@ def __init__(
)
for key, proxy in proxy_map.items()
}
self._proxies = dict(sorted(self._proxies.items()))
if mounts is not None:
self._mounts.update(
{URLPattern(key): transport for key, transport in mounts.items()}
)
self._mounts = dict(sorted(self._mounts.items()))

def _init_transport(
self,
Expand Down Expand Up @@ -1319,7 +1330,7 @@ def _transport_for_url(self, url: URL) -> httpcore.AsyncHTTPTransport:
Returns the transport instance that should be used for a given URL.
This will either be the standard connection pool, or a proxy.
"""
for pattern, transport in self._proxies.items():
for pattern, transport in self._mounts.items():
if pattern.matches(url):
return self._transport if transport is None else transport

Expand Down Expand Up @@ -1499,7 +1510,7 @@ async def _send_single_request(
await timer.async_start()

with map_exceptions(HTTPCORE_EXC_MAP, request=request):
(status_code, headers, stream, ext,) = await transport.arequest(
(status_code, headers, stream, ext) = await transport.arequest(
request.method.encode(),
request.url.raw,
headers=request.headers.raw,
Expand Down Expand Up @@ -1750,15 +1761,15 @@ async def aclose(self) -> None:
self._state = ClientState.CLOSED

await self._transport.aclose()
for proxy in self._proxies.values():
for proxy in self._mounts.values():
if proxy is not None:
await proxy.aclose()

async def __aenter__(self: U) -> U:
self._state = ClientState.OPENED

await self._transport.__aenter__()
for proxy in self._proxies.values():
for proxy in self._mounts.values():
if proxy is not None:
await proxy.__aenter__()
return self
Expand All @@ -1772,7 +1783,7 @@ async def __aexit__(
self._state = ClientState.CLOSED

await self._transport.__aexit__(exc_type, exc_value, traceback)
for proxy in self._proxies.values():
for proxy in self._mounts.values():
if proxy is not None:
await proxy.__aexit__(exc_type, exc_value, traceback)

Expand Down
76 changes: 68 additions & 8 deletions tests/client/test_async_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
from datetime import timedelta

import httpcore
Expand Down Expand Up @@ -188,15 +189,8 @@ async def __aexit__(self, *args):
await super().__aexit__(*args)
self.events.append("transport.__aexit__")

# Note that we're including 'proxies' here to *also* run through the
# proxy context management, although we can't easily test that at the
# moment, since we can't add proxies as transport instances.
#
# Once we have a more generalised Mount API we'll be able to remove this
# in favour of ensuring all mounts are context managed, which will
# also neccessarily include proxies.
transport = Transport()
async with httpx.AsyncClient(transport=transport, proxies="http://www.example.com"):
async with httpx.AsyncClient(transport=transport):
pass

assert transport.events == [
Expand All @@ -206,6 +200,47 @@ async def __aexit__(self, *args):
]


@pytest.mark.usefixtures("async_environment")
async def test_context_managed_transport_and_mount():
class Transport(httpcore.AsyncHTTPTransport):
def __init__(self, name: str):
self.name: str = name
self.events: typing.List[str] = []

async def aclose(self):
# The base implementation of httpcore.AsyncHTTPTransport just
# calls into `.aclose`, so simple transport cases can just override
# this method for any cleanup, where more complex cases
# might want to additionally override `__aenter__`/`__aexit__`.
self.events.append(f"{self.name}.aclose")

async def __aenter__(self):
await super().__aenter__()
self.events.append(f"{self.name}.__aenter__")

async def __aexit__(self, *args):
await super().__aexit__(*args)
self.events.append(f"{self.name}.__aexit__")

transport = Transport(name="transport")
mounted = Transport(name="mounted")
async with httpx.AsyncClient(
transport=transport, mounts={"http://www.example.org": mounted}
):
pass

assert transport.events == [
"transport.__aenter__",
"transport.aclose",
"transport.__aexit__",
]
assert mounted.events == [
"mounted.__aenter__",
"mounted.aclose",
"mounted.__aexit__",
]


def hello_world(request):
return httpx.Response(200, text="Hello, world!")

Expand Down Expand Up @@ -242,3 +277,28 @@ async def test_deleting_unclosed_async_client_causes_warning():
await client.get("http://example.com")
with pytest.warns(UserWarning):
del client


def unmounted(request: httpx.Request) -> httpx.Response:
data = {"app": "unmounted"}
return httpx.Response(200, json=data)


def mounted(request: httpx.Request) -> httpx.Response:
data = {"app": "mounted"}
return httpx.Response(200, json=data)


@pytest.mark.usefixtures("async_environment")
async def test_mounted_transport():
transport = MockTransport(unmounted)
mounts = {"custom://": MockTransport(mounted)}

async with httpx.AsyncClient(transport=transport, mounts=mounts) as client:
response = await client.get("https://www.example.com")
assert response.status_code == 200
assert response.json() == {"app": "unmounted"}

response = await client.get("custom://www.example.com")
assert response.status_code == 200
assert response.json() == {"app": "mounted"}
Loading