Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistent multidict methods #1089

Merged
merged 5 commits into from
Jul 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 70 additions & 15 deletions httpx/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json as jsonlib
import typing
import urllib.request
import warnings
from collections.abc import MutableMapping
from http.cookiejar import Cookie, CookieJar
from urllib.parse import parse_qsl, urlencode
Expand Down Expand Up @@ -240,26 +241,40 @@ def __init__(self, *args: QueryParamTypes, **kwargs: typing.Any) -> None:
self._list = [(str(k), str_query_param(v)) for k, v in items]
self._dict = {str(k): str_query_param(v) for k, v in items}

def getlist(self, key: typing.Any) -> typing.List[str]:
return [item_value for item_key, item_value in self._list if item_key == key]

def keys(self) -> typing.KeysView:
return self._dict.keys()

def values(self) -> typing.ValuesView:
return self._dict.values()

def items(self) -> typing.ItemsView:
"""
Return all items in the query params. If a key occurs more than once
only the first item for that key is returned.
"""
return self._dict.items()

def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
"""
Return all items in the query params. Allow duplicate keys to occur.
"""
return list(self._list)

def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
"""
Get a value from the query param for a given key. If the key occurs
more than once, then only the first value is returned.
"""
if key in self._dict:
return self._dict[key]
return default

def get_list(self, key: typing.Any) -> typing.List[str]:
"""
Get all values from the query param for a given key.
"""
return [item_value for item_key, item_value in self._list if item_key == key]

def update(self, params: QueryParamTypes = None) -> None:
if not params:
return
Expand Down Expand Up @@ -315,6 +330,13 @@ def __repr__(self) -> str:
query_string = str(self)
return f"{class_name}({query_string!r})"

def getlist(self, key: typing.Any) -> typing.List[str]:
message = (
"QueryParams.getlist() is pending deprecation. Use QueryParams.get_list()"
)
warnings.warn(message, PendingDeprecationWarning)
return self.get_list(key)


class Headers(typing.MutableMapping[str, str]):
"""
Expand All @@ -336,6 +358,14 @@ def __init__(self, headers: HeaderTypes = None, encoding: str = None) -> None:
(normalize_header_key(k, encoding), normalize_header_value(v, encoding))
for k, v in headers
]

self._dict = {} # type: typing.Dict[bytes, bytes]
for key, value in self._list:
if key in self._dict:
self._dict[key] = self._dict[key] + b", " + value
else:
self._dict[key] = value

self._encoding = encoding

@property
Expand Down Expand Up @@ -376,26 +406,47 @@ def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
return self._list

def keys(self) -> typing.List[str]: # type: ignore
return [key.decode(self.encoding) for key, value in self._list]
return [key.decode(self.encoding) for key in self._dict.keys()]

def values(self) -> typing.List[str]: # type: ignore
return [value.decode(self.encoding) for key, value in self._list]
return [value.decode(self.encoding) for value in self._dict.values()]

def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore
"""
Return a list of `(key, value)` pairs of headers. Concatenate headers
into a single comma seperated value when a key occurs multiple times.
"""
return [
(key.decode(self.encoding), value.decode(self.encoding))
for key, value in self._dict.items()
]

def multi_items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore
"""
Return a list of `(key, value)` pairs of headers. Allow multiple
occurences of the same key without concatenating into a single
comma seperated value.
"""
return [
(key.decode(self.encoding), value.decode(self.encoding))
for key, value in self._list
]

def get(self, key: str, default: typing.Any = None) -> typing.Any:
"""
Return a header value. If multiple occurences of the header occur
then concatenate them together with commas.
"""
try:
return self[key]
except KeyError:
return default

def getlist(self, key: str, split_commas: bool = False) -> typing.List[str]:
def get_list(self, key: str, split_commas: bool = False) -> typing.List[str]:
"""
Return multiple header values.
Return a list of all header values for a given key.
If `split_commas=True` is passed, then any comma seperated header
values are split into multiple return strings.
"""
get_header_key = key.lower().encode(self.encoding)

Expand Down Expand Up @@ -448,6 +499,8 @@ def __setitem__(self, key: str, value: str) -> None:
set_key = key.lower().encode(self._encoding or "utf-8")
set_value = value.encode(self._encoding or "utf-8")

self._dict[set_key] = set_value

found_indexes = []
for idx, (item_key, _) in enumerate(self._list):
if item_key == set_key:
Expand All @@ -468,22 +521,19 @@ def __delitem__(self, key: str) -> None:
"""
del_key = key.lower().encode(self.encoding)

del self._dict[del_key]

pop_indexes = []
for idx, (item_key, _) in enumerate(self._list):
if item_key == del_key:
pop_indexes.append(idx)
if not pop_indexes:
raise KeyError(key)

for idx in reversed(pop_indexes):
del self._list[idx]

def __contains__(self, key: typing.Any) -> bool:
get_header_key = key.lower().encode(self.encoding)
for header_key, _ in self._list:
if header_key == get_header_key:
return True
return False
header_key = key.lower().encode(self.encoding)
return header_key in self._dict

def __iter__(self) -> typing.Iterator[typing.Any]:
return iter(self.keys())
Expand All @@ -503,14 +553,19 @@ def __repr__(self) -> str:
if self.encoding != "ascii":
encoding_str = f", encoding={self.encoding!r}"

as_list = list(obfuscate_sensitive_headers(self.items()))
as_list = list(obfuscate_sensitive_headers(self.multi_items()))
as_dict = dict(as_list)

no_duplicate_keys = len(as_dict) == len(as_list)
if no_duplicate_keys:
return f"{class_name}({as_dict!r}{encoding_str})"
return f"{class_name}({as_list!r}{encoding_str})"

def getlist(self, key: str, split_commas: bool = False) -> typing.List[str]:
message = "Headers.getlist() is pending deprecation. Use Headers.get_list()"
warnings.warn(message, PendingDeprecationWarning)
return self.get_list(key, split_commas=split_commas)


USER_AGENT = f"python-httpx/{__version__}"
ACCEPT_ENCODING = ", ".join(
Expand Down
17 changes: 9 additions & 8 deletions tests/models/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ def test_headers():
assert h["a"] == "123, 456"
assert h.get("a") == "123, 456"
assert h.get("nope", default=None) is None
assert h.getlist("a") == ["123", "456"]
assert h.keys() == ["a", "a", "b"]
assert h.values() == ["123", "456", "789"]
assert h.items() == [("a", "123"), ("a", "456"), ("b", "789")]
assert list(h) == ["a", "a", "b"]
assert h.get_list("a") == ["123", "456"]
assert h.keys() == ["a", "b"]
assert h.values() == ["123, 456", "789"]
assert h.items() == [("a", "123, 456"), ("b", "789")]
assert h.multi_items() == [("a", "123"), ("a", "456"), ("b", "789")]
assert list(h) == ["a", "b"]
assert dict(h) == {"a": "123, 456", "b": "789"}
assert repr(h) == "Headers([('a', '123'), ('a', '456'), ('b', '789')])"
assert h == httpx.Headers([("a", "123"), ("b", "789"), ("a", "456")])
Expand Down Expand Up @@ -153,13 +154,13 @@ def test_headers_decode_explicit_encoding():

def test_multiple_headers():
"""
Most headers should split by commas for `getlist`, except 'Set-Cookie'.
`Headers.get_list` should support both split_commas=False and split_commas=True.
"""
h = httpx.Headers([("set-cookie", "a, b"), ("set-cookie", "c")])
h.getlist("Set-Cookie") == ["a, b", "b"]
assert h.get_list("Set-Cookie") == ["a, b", "c"]

h = httpx.Headers([("vary", "a, b"), ("vary", "c")])
h.getlist("Vary") == ["a", "b", "c"]
assert h.get_list("Vary", split_commas=True) == ["a", "b", "c"]


@pytest.mark.parametrize("header", ["authorization", "proxy-authorization"])
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_queryparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_queryparams(source):
assert q["a"] == "456"
assert q.get("a") == "456"
assert q.get("nope", default=None) is None
assert q.getlist("a") == ["123", "456"]
assert q.get_list("a") == ["123", "456"]
assert list(q.keys()) == ["a", "b"]
assert list(q.values()) == ["456", "789"]
assert list(q.items()) == [("a", "456"), ("b", "789")]
Expand Down