diff --git a/httpx/_models.py b/httpx/_models.py index c8cbfbb449..4a81e5965d 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -4,6 +4,7 @@ import json as jsonlib import typing import urllib.request +import warnings from collections.abc import MutableMapping from http.cookiejar import Cookie, CookieJar from urllib.parse import parse_qsl, urlencode @@ -240,9 +241,6 @@ def __init__(self, *args: QueryParamTypes, **kwargs: typing.Any) -> None: self._list = [(str(k), str_query_param(v)) for k, v in items] self._dict = {str(k): str_query_param(v) for k, v in items} - def getlist(self, key: typing.Any) -> typing.List[str]: - return [item_value for item_key, item_value in self._list if item_key == key] - def keys(self) -> typing.KeysView: return self._dict.keys() @@ -250,16 +248,33 @@ def values(self) -> typing.ValuesView: return self._dict.values() def items(self) -> typing.ItemsView: + """ + Return all items in the query params. If a key occurs more than once + only the first item for that key is returned. + """ return self._dict.items() def multi_items(self) -> typing.List[typing.Tuple[str, str]]: + """ + Return all items in the query params. Allow duplicate keys to occur. + """ return list(self._list) def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + """ + Get a value from the query param for a given key. If the key occurs + more than once, then only the first value is returned. + """ if key in self._dict: return self._dict[key] return default + def get_list(self, key: typing.Any) -> typing.List[str]: + """ + Get all values from the query param for a given key. + """ + return [item_value for item_key, item_value in self._list if item_key == key] + def update(self, params: QueryParamTypes = None) -> None: if not params: return @@ -315,6 +330,13 @@ def __repr__(self) -> str: query_string = str(self) return f"{class_name}({query_string!r})" + def getlist(self, key: typing.Any) -> typing.List[str]: + message = ( + "QueryParams.getlist() is pending deprecation. Use QueryParams.get_list()" + ) + warnings.warn(message, PendingDeprecationWarning) + return self.get_list(key) + class Headers(typing.MutableMapping[str, str]): """ @@ -336,6 +358,14 @@ def __init__(self, headers: HeaderTypes = None, encoding: str = None) -> None: (normalize_header_key(k, encoding), normalize_header_value(v, encoding)) for k, v in headers ] + + self._dict = {} # type: typing.Dict[bytes, bytes] + for key, value in self._list: + if key in self._dict: + self._dict[key] = self._dict[key] + b", " + value + else: + self._dict[key] = value + self._encoding = encoding @property @@ -376,26 +406,47 @@ def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]: return self._list def keys(self) -> typing.List[str]: # type: ignore - return [key.decode(self.encoding) for key, value in self._list] + return [key.decode(self.encoding) for key in self._dict.keys()] def values(self) -> typing.List[str]: # type: ignore - return [value.decode(self.encoding) for key, value in self._list] + return [value.decode(self.encoding) for value in self._dict.values()] def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore + """ + Return a list of `(key, value)` pairs of headers. Concatenate headers + into a single comma seperated value when a key occurs multiple times. + """ + return [ + (key.decode(self.encoding), value.decode(self.encoding)) + for key, value in self._dict.items() + ] + + def multi_items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore + """ + Return a list of `(key, value)` pairs of headers. Allow multiple + occurences of the same key without concatenating into a single + comma seperated value. + """ return [ (key.decode(self.encoding), value.decode(self.encoding)) for key, value in self._list ] def get(self, key: str, default: typing.Any = None) -> typing.Any: + """ + Return a header value. If multiple occurences of the header occur + then concatenate them together with commas. + """ try: return self[key] except KeyError: return default - def getlist(self, key: str, split_commas: bool = False) -> typing.List[str]: + def get_list(self, key: str, split_commas: bool = False) -> typing.List[str]: """ - Return multiple header values. + Return a list of all header values for a given key. + If `split_commas=True` is passed, then any comma seperated header + values are split into multiple return strings. """ get_header_key = key.lower().encode(self.encoding) @@ -448,6 +499,8 @@ def __setitem__(self, key: str, value: str) -> None: set_key = key.lower().encode(self._encoding or "utf-8") set_value = value.encode(self._encoding or "utf-8") + self._dict[set_key] = set_value + found_indexes = [] for idx, (item_key, _) in enumerate(self._list): if item_key == set_key: @@ -468,22 +521,19 @@ def __delitem__(self, key: str) -> None: """ del_key = key.lower().encode(self.encoding) + del self._dict[del_key] + pop_indexes = [] for idx, (item_key, _) in enumerate(self._list): if item_key == del_key: pop_indexes.append(idx) - if not pop_indexes: - raise KeyError(key) for idx in reversed(pop_indexes): del self._list[idx] def __contains__(self, key: typing.Any) -> bool: - get_header_key = key.lower().encode(self.encoding) - for header_key, _ in self._list: - if header_key == get_header_key: - return True - return False + header_key = key.lower().encode(self.encoding) + return header_key in self._dict def __iter__(self) -> typing.Iterator[typing.Any]: return iter(self.keys()) @@ -503,7 +553,7 @@ def __repr__(self) -> str: if self.encoding != "ascii": encoding_str = f", encoding={self.encoding!r}" - as_list = list(obfuscate_sensitive_headers(self.items())) + as_list = list(obfuscate_sensitive_headers(self.multi_items())) as_dict = dict(as_list) no_duplicate_keys = len(as_dict) == len(as_list) @@ -511,6 +561,11 @@ def __repr__(self) -> str: return f"{class_name}({as_dict!r}{encoding_str})" return f"{class_name}({as_list!r}{encoding_str})" + def getlist(self, key: str, split_commas: bool = False) -> typing.List[str]: + message = "Headers.getlist() is pending deprecation. Use Headers.get_list()" + warnings.warn(message, PendingDeprecationWarning) + return self.get_list(key, split_commas=split_commas) + USER_AGENT = f"python-httpx/{__version__}" ACCEPT_ENCODING = ", ".join( diff --git a/tests/models/test_headers.py b/tests/models/test_headers.py index 088f5c8a1f..ce08816d16 100644 --- a/tests/models/test_headers.py +++ b/tests/models/test_headers.py @@ -13,11 +13,12 @@ def test_headers(): assert h["a"] == "123, 456" assert h.get("a") == "123, 456" assert h.get("nope", default=None) is None - assert h.getlist("a") == ["123", "456"] - assert h.keys() == ["a", "a", "b"] - assert h.values() == ["123", "456", "789"] - assert h.items() == [("a", "123"), ("a", "456"), ("b", "789")] - assert list(h) == ["a", "a", "b"] + assert h.get_list("a") == ["123", "456"] + assert h.keys() == ["a", "b"] + assert h.values() == ["123, 456", "789"] + assert h.items() == [("a", "123, 456"), ("b", "789")] + assert h.multi_items() == [("a", "123"), ("a", "456"), ("b", "789")] + assert list(h) == ["a", "b"] assert dict(h) == {"a": "123, 456", "b": "789"} assert repr(h) == "Headers([('a', '123'), ('a', '456'), ('b', '789')])" assert h == httpx.Headers([("a", "123"), ("b", "789"), ("a", "456")]) @@ -153,13 +154,13 @@ def test_headers_decode_explicit_encoding(): def test_multiple_headers(): """ - Most headers should split by commas for `getlist`, except 'Set-Cookie'. + `Headers.get_list` should support both split_commas=False and split_commas=True. """ h = httpx.Headers([("set-cookie", "a, b"), ("set-cookie", "c")]) - h.getlist("Set-Cookie") == ["a, b", "b"] + assert h.get_list("Set-Cookie") == ["a, b", "c"] h = httpx.Headers([("vary", "a, b"), ("vary", "c")]) - h.getlist("Vary") == ["a", "b", "c"] + assert h.get_list("Vary", split_commas=True) == ["a", "b", "c"] @pytest.mark.parametrize("header", ["authorization", "proxy-authorization"]) diff --git a/tests/models/test_queryparams.py b/tests/models/test_queryparams.py index 99f193f86d..39f57e2e7f 100644 --- a/tests/models/test_queryparams.py +++ b/tests/models/test_queryparams.py @@ -19,7 +19,7 @@ def test_queryparams(source): assert q["a"] == "456" assert q.get("a") == "456" assert q.get("nope", default=None) is None - assert q.getlist("a") == ["123", "456"] + assert q.get_list("a") == ["123", "456"] assert list(q.keys()) == ["a", "b"] assert list(q.values()) == ["456", "789"] assert list(q.items()) == [("a", "456"), ("b", "789")]