Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop URL(allow_relative=bool) #1073

Merged
merged 5 commits into from
Jul 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from ._utils import (
NetRCInfo,
enforce_http_url,
get_environment_proxies,
get_logger,
same_origin,
Expand All @@ -69,7 +70,7 @@ def __init__(
trust_env: bool = True,
):
if base_url is None:
self.base_url = URL("", allow_relative=True)
self.base_url = URL("")
else:
self.base_url = URL(base_url)

Expand Down Expand Up @@ -318,7 +319,7 @@ def _redirect_url(self, request: Request, response: Response) -> URL:
"""
location = response.headers["Location"]

url = URL(location, allow_relative=True)
url = URL(location)

# Check that we can handle the scheme
if url.scheme and url.scheme not in ("http", "https"):
Expand Down Expand Up @@ -539,6 +540,8 @@ def _transport_for_url(self, url: URL) -> httpcore.SyncHTTPTransport:
Returns the transport instance that should be used for a given URL.
This will either be the standard connection pool, or a proxy.
"""
enforce_http_url(url)

if self._proxies and not should_not_be_proxied(url):
is_default_port = (url.scheme == "http" and url.port == 80) or (
url.scheme == "https" and url.port == 443
Expand Down Expand Up @@ -690,7 +693,6 @@ def _send_single_request(self, request: Request, timeout: Timeout) -> Response:
"""
Sends a single request, without handling any redirections.
"""

transport = self._transport_for_url(request.url)

with map_exceptions(HTTPCORE_EXC_MAP, request=request):
Expand Down Expand Up @@ -1071,6 +1073,8 @@ def _transport_for_url(self, url: URL) -> httpcore.AsyncHTTPTransport:
Returns the transport instance that should be used for a given URL.
This will either be the standard connection pool, or a proxy.
"""
enforce_http_url(url)

if self._proxies and not should_not_be_proxied(url):
is_default_port = (url.scheme == "http" and url.port == 80) or (
url.scheme == "https" and url.port == 443
Expand Down Expand Up @@ -1130,9 +1134,6 @@ async def send(
allow_redirects: bool = True,
timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
) -> Response:
if request.url.scheme not in ("http", "https"):
raise InvalidURL('URL scheme must be "http" or "https".')

timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)

auth = self._build_auth(request, auth)
Expand Down Expand Up @@ -1225,7 +1226,6 @@ async def _send_single_request(
"""
Sends a single request, without handling any redirections.
"""

transport = self._transport_for_url(request.url)

with map_exceptions(HTTPCORE_EXC_MAP, request=request):
Expand Down
22 changes: 3 additions & 19 deletions httpx/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from ._exceptions import (
CookieConflict,
HTTPStatusError,
InvalidURL,
NotRedirectResponse,
RequestNotRead,
ResponseClosed,
Expand Down Expand Up @@ -55,12 +54,7 @@


class URL:
def __init__(
self,
url: URLTypes,
allow_relative: bool = False,
params: QueryParamTypes = None,
) -> None:
def __init__(self, url: URLTypes, params: QueryParamTypes = None) -> None:
if isinstance(url, str):
self._uri_reference = rfc3986.api.iri_reference(url).encode()
else:
Expand All @@ -80,13 +74,6 @@ def __init__(
query_string = str(QueryParams(params))
self._uri_reference = self._uri_reference.copy_with(query=query_string)

# Enforce absolute URLs by default.
if not allow_relative:
if not self.scheme:
raise InvalidURL("No scheme included in URL.")
if not self.host:
raise InvalidURL("No host included in URL.")

@property
def scheme(self) -> str:
return self._uri_reference.scheme or ""
Expand Down Expand Up @@ -195,10 +182,7 @@ def copy_with(self, **kwargs: typing.Any) -> "URL":

kwargs["authority"] = authority

return URL(
self._uri_reference.copy_with(**kwargs).unsplit(),
allow_relative=self.is_relative_url,
)
return URL(self._uri_reference.copy_with(**kwargs).unsplit(),)

def join(self, relative_url: URLTypes) -> "URL":
"""
Expand All @@ -210,7 +194,7 @@ def join(self, relative_url: URLTypes) -> "URL":
# We drop any fragment portion, because RFC 3986 strictly
# treats URLs with a fragment portion as not being absolute URLs.
base_uri = self._uri_reference.copy_with(fragment=None)
relative_url = URL(relative_url, allow_relative=True)
relative_url = URL(relative_url)
return URL(relative_url._uri_reference.resolve_with(base_uri).unsplit())

def __hash__(self) -> int:
Expand Down
13 changes: 13 additions & 0 deletions httpx/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from types import TracebackType
from urllib.request import getproxies

from ._exceptions import InvalidURL
from ._types import PrimitiveData

if typing.TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -260,6 +261,18 @@ def trace(message: str, *args: typing.Any, **kwargs: typing.Any) -> None:
return typing.cast(Logger, logger)


def enforce_http_url(url: "URL") -> None:
"""
Raise an appropriate InvalidURL for any non-HTTP URLs.
"""
if not url.scheme:
raise InvalidURL("No scheme included in URL.")
if not url.host:
raise InvalidURL("No host included in URL.")
if url.scheme not in ("http", "https"):
raise InvalidURL('URL scheme must be "http" or "https".')


def same_origin(url: "URL", other: "URL") -> bool:
"""
Return 'True' if the given URLs share the same origin.
Expand Down
4 changes: 4 additions & 0 deletions tests/client/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ async def test_get_invalid_url(server):
async with httpx.AsyncClient() as client:
with pytest.raises(httpx.InvalidURL):
await client.get("invalid://example.org")
with pytest.raises(httpx.InvalidURL):
await client.get("://example.org")
with pytest.raises(httpx.InvalidURL):
await client.get("http://")
tomchristie marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.usefixtures("async_environment")
Expand Down
10 changes: 10 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ def test_get(server):
assert response.elapsed > timedelta(0)


def test_get_invalid_url(server):
with httpx.Client() as client:
with pytest.raises(httpx.InvalidURL):
client.get("invalid://example.org")
with pytest.raises(httpx.InvalidURL):
client.get("://example.org")
with pytest.raises(httpx.InvalidURL):
client.get("http://")
Comment on lines +25 to +32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Nit) Considering using a parametrized test style? I enjoy the documentation purpose of id=..., and that each case runs as its own test...

Then should we back-port this to test_async_client.py as well?

Suggested change
def test_get_invalid_url(server):
with httpx.Client() as client:
with pytest.raises(httpx.InvalidURL):
client.get("invalid://example.org")
with pytest.raises(httpx.InvalidURL):
client.get("://example.org")
with pytest.raises(httpx.InvalidURL):
client.get("http://")
@pytest.mark.parametrize(
"url",
[
pytest.param("invalid://example.org", id="scheme-not-http(s)"),
pytest.param("://example.org", id="no-scheme"),
pytest.param("http://", id="no-host"),
],
)
def test_get_invalid_url(server, url):
with httpx.Client() as client:
with pytest.raises(httpx.InvalidURL):
client.get(url)



def test_build_request(server):
url = server.url.copy_with(path="/echo_headers")
headers = {"Custom-header": "value"}
Expand Down
8 changes: 0 additions & 8 deletions tests/models/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,3 @@ def test_url():
assert request.url.scheme == "https"
assert request.url.port == 443
assert request.url.full_path == "/abc?foo=bar"


def test_invalid_urls():
with pytest.raises(httpx.InvalidURL):
httpx.Request("GET", "example.org")

with pytest.raises(httpx.InvalidURL):
httpx.Request("GET", "http:///foo")
5 changes: 1 addition & 4 deletions tests/models/test_url.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from httpx import URL, InvalidURL
from httpx import URL


@pytest.mark.parametrize(
Expand Down Expand Up @@ -116,9 +116,6 @@ def test_url_join_rfc3986():

url = URL("http://example.com/b/c/d;p?q")

with pytest.raises(InvalidURL):
assert url.join("g:h") == "g:h"

assert url.join("g") == "http://example.com/b/c/g"
assert url.join("./g") == "http://example.com/b/c/g"
assert url.join("g/") == "http://example.com/b/c/g/"
Expand Down