Skip to content

Commit

Permalink
Expand URL interface (#1601)
Browse files Browse the repository at this point in the history
* Expand URL interface

* Add URL query param manipulation methods
  • Loading branch information
tomchristie authored Apr 27, 2021
1 parent 2abb2f2 commit e67b0dd
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 18 deletions.
65 changes: 48 additions & 17 deletions httpx/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class URL:
"""

def __init__(
self, url: typing.Union["URL", str, RawURL] = "", params: QueryParamTypes = None
self, url: typing.Union["URL", str, RawURL] = "", **kwargs: typing.Any
) -> None:
if isinstance(url, (str, tuple)):
if isinstance(url, tuple):
Expand Down Expand Up @@ -144,14 +144,8 @@ def __init__(
f"Invalid type for url. Expected str or httpx.URL, got {type(url)}: {url!r}"
)

# Add any query parameters, merging with any in the URL if needed.
if params:
if self._uri_reference.query:
url_params = QueryParams(self._uri_reference.query).merge(params)
query_string = str(url_params)
else:
query_string = str(QueryParams(params))
self._uri_reference = self._uri_reference.copy_with(query=query_string)
if kwargs:
self._uri_reference = self.copy_with(**kwargs)._uri_reference

@property
def scheme(self) -> str:
Expand Down Expand Up @@ -293,12 +287,27 @@ def path(self) -> str:
def query(self) -> bytes:
"""
The URL query string, as raw bytes, excluding the leading b"?".
Note that URL decoding can only be applied on URL query strings
at the point of decoding the individual parameter names/values.
This is neccessarily a bytewise interface, because we cannot
perform URL decoding of this representation until we've parsed
the keys and values into a QueryParams instance.
For example:
url = httpx.URL("https://example.com/?filter=some%20search%20terms")
assert url.query == b"filter=some%20search%20terms"
"""
query = self._uri_reference.query or ""
return query.encode("ascii")

@property
def params(self) -> "QueryParams":
"""
The URL query parameters, neatly parsed and packaged into an immutable
multidict representation.
"""
return QueryParams(self._uri_reference.query)

@property
def raw_path(self) -> bytes:
"""
Expand Down Expand Up @@ -382,6 +391,7 @@ def copy_with(self, **kwargs: typing.Any) -> "URL":
"query": bytes,
"raw_path": bytes,
"fragment": str,
"params": object,
}
for key, value in kwargs.items():
if key not in allowed:
Expand Down Expand Up @@ -434,12 +444,28 @@ def copy_with(self, **kwargs: typing.Any) -> "URL":
if kwargs.get("path") is not None:
kwargs["path"] = quote(kwargs["path"])

# Ensure query=<str> for rfc3986
if kwargs.get("query") is not None:
# Ensure query=<str> for rfc3986
kwargs["query"] = kwargs["query"].decode("ascii")

if "params" in kwargs:
params = kwargs.pop("params")
kwargs["query"] = None if not params else str(QueryParams(params))

return URL(self._uri_reference.copy_with(**kwargs).unsplit())

def copy_set_param(self, key: str, value: typing.Any = None) -> "URL":
return self.copy_with(params=self.params.set(key, value))

def copy_add_param(self, key: str, value: typing.Any = None) -> "URL":
return self.copy_with(params=self.params.add(key, value))

def copy_remove_param(self, key: str) -> "URL":
return self.copy_with(params=self.params.remove(key))

def copy_merge_params(self, params: QueryParamTypes) -> "URL":
return self.copy_with(params=self.params.merge(params))

def join(self, url: URLTypes) -> "URL":
"""
Return an absolute URL, using this URL as the base.
Expand Down Expand Up @@ -595,7 +621,7 @@ def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
return self._dict[str(key)][0]
return default

def get_list(self, key: typing.Any) -> typing.List[str]:
def get_list(self, key: str) -> typing.List[str]:
"""
Get all values from the query param for a given key.
Expand All @@ -606,7 +632,7 @@ def get_list(self, key: typing.Any) -> typing.List[str]:
"""
return list(self._dict.get(str(key), []))

def set(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
def set(self, key: str, value: typing.Any = None) -> "QueryParams":
"""
Return a new QueryParams instance, setting the value of a key.
Expand All @@ -621,7 +647,7 @@ def set(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
q._dict[str(key)] = [primitive_value_to_str(value)]
return q

def add(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
def add(self, key: str, value: typing.Any = None) -> "QueryParams":
"""
Return a new QueryParams instance, setting or appending the value of a key.
Expand All @@ -636,7 +662,7 @@ def add(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)]
return q

def remove(self, key: typing.Any) -> "QueryParams":
def remove(self, key: str) -> "QueryParams":
"""
Return a new QueryParams instance, removing the value of a key.
Expand Down Expand Up @@ -681,6 +707,9 @@ def __iter__(self) -> typing.Iterator[typing.Any]:
def __len__(self) -> int:
return len(self._dict)

def __bool__(self) -> bool:
return bool(self._dict)

def __hash__(self) -> int:
return hash(str(self))

Expand Down Expand Up @@ -971,7 +1000,9 @@ def __init__(
self.method = method.decode("ascii").upper()
else:
self.method = method.upper()
self.url = URL(url, params=params)
self.url = URL(url)
if params is not None:
self.url = self.url.copy_merge_params(params=params)
self.headers = Headers(headers)
if cookies:
Cookies(cookies).set_cookie_header(self)
Expand Down
36 changes: 35 additions & 1 deletion tests/models/test_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,13 @@ def test_url_eq_str():
def test_url_params():
url = httpx.URL("https://example.org:123/path/to/somewhere", params={"a": "123"})
assert str(url) == "https://example.org:123/path/to/somewhere?a=123"
assert url.params == httpx.QueryParams({"a": "123"})

url = httpx.URL(
"https://example.org:123/path/to/somewhere?b=456", params={"a": "123"}
)
assert str(url) == "https://example.org:123/path/to/somewhere?b=456&a=123"
assert str(url) == "https://example.org:123/path/to/somewhere?a=123"
assert url.params == httpx.QueryParams({"a": "123"})


def test_url_join():
Expand All @@ -122,6 +124,38 @@ def test_url_join():
assert url.join("../../somewhere-else") == "https://example.org:123/somewhere-else"


def test_url_set_param_manipulation():
"""
Some basic URL query parameter manipulation.
"""
url = httpx.URL("https://example.org:123/?a=123")
assert url.copy_set_param("a", "456") == "https://example.org:123/?a=456"


def test_url_add_param_manipulation():
"""
Some basic URL query parameter manipulation.
"""
url = httpx.URL("https://example.org:123/?a=123")
assert url.copy_add_param("a", "456") == "https://example.org:123/?a=123&a=456"


def test_url_remove_param_manipulation():
"""
Some basic URL query parameter manipulation.
"""
url = httpx.URL("https://example.org:123/?a=123")
assert url.copy_remove_param("a") == "https://example.org:123/"


def test_url_merge_params_manipulation():
"""
Some basic URL query parameter manipulation.
"""
url = httpx.URL("https://example.org:123/?a=123")
assert url.copy_merge_params({"b": "456"}) == "https://example.org:123/?a=123&b=456"


def test_relative_url_join():
url = httpx.URL("/path/to/somewhere")
assert url.join("/somewhere-else") == "/somewhere-else"
Expand Down

0 comments on commit e67b0dd

Please sign in to comment.