From 61006b8dec343a9c5f7df659da1fc252e89dffe2 Mon Sep 17 00:00:00 2001
From: Chenglong Hao
Date: Wed, 25 Oct 2023 15:23:00 -0700
Subject: [PATCH 1/6] pass in requests.Session

---
 tests/main_test.py        | 19 +++++++++++++
 tldextract/suffix_list.py | 17 ++++++++++--
 tldextract/tldextract.py  | 57 ++++++++++++++++++++++++++++-----------
 3 files changed, 75 insertions(+), 18 deletions(-)

diff --git a/tests/main_test.py b/tests/main_test.py
index bf6f7a79..64bccde8 100644
--- a/tests/main_test.py
+++ b/tests/main_test.py
@@ -11,6 +11,7 @@
 
 import pytest
 import pytest_mock
+import requests
 import responses
 
 import tldextract
@@ -449,6 +450,24 @@ def test_cache_timeouts(tmp_path: Path) -> None:
         tldextract.suffix_list.find_first_response(cache, [server], 5)
 
 
+@responses.activate
+def test_find_first_response(tmp_path: Path) -> None:
+    """Test it is able to find first response."""
+    server = "http://some-server.com"
+    response_text = "server response"
+    responses.add(responses.GET, server, status=200, body=response_text)
+    cache = DiskCache(str(tmp_path))
+
+    # without session passed in
+    result = tldextract.suffix_list.find_first_response(cache, [server], 5)
+    assert result == response_text
+
+    # with session passed in
+    session = requests.Session()
+    result = tldextract.suffix_list.find_first_response(cache, [server], 5, session)
+    assert result == response_text
+
+
 def test_include_psl_private_domain_attr() -> None:
     """Test private domains, which default to not being treated differently."""
     extract_private = tldextract.TLDExtract(include_psl_private_domains=True)
diff --git a/tldextract/suffix_list.py b/tldextract/suffix_list.py
index 62427367..cb6f4fc6 100644
--- a/tldextract/suffix_list.py
+++ b/tldextract/suffix_list.py
@@ -31,11 +31,15 @@ def find_first_response(
     cache: DiskCache,
     urls: Sequence[str],
     cache_fetch_timeout: float | int | None = None,
+    session: requests.Session = None,
 ) -> str:
     """Decode the first successfully fetched URL, from UTF-8 encoding to Python unicode."""
-    with requests.Session() as session:
+
+    if session is None:
+        session = requests.Session()
         session.mount("file://", FileAdapter())
 
+    try:
         for url in urls:
             try:
                 return cache.cached_fetch_url(
@@ -43,6 +47,10 @@
                 )
             except requests.exceptions.RequestException:
                 LOG.exception("Exception reading Public Suffix List url %s", url)
+    finally:
+        # Ensure the session is always closed
+        session.close()
+
     raise SuffixListNotFound(
         "No remote Public Suffix List found. Consider using a mirror, or avoid this"
         " fetch by constructing your TLDExtract with `suffix_list_urls=()`."
@@ -65,6 +73,7 @@ def get_suffix_lists(
     urls: Sequence[str],
     cache_fetch_timeout: float | int | None,
     fallback_to_snapshot: bool,
+    session: requests.Session = None,
 ) -> tuple[list[str], list[str]]:
     """Fetch, parse, and cache the suffix lists."""
     return cache.run_and_cache(
@@ -75,6 +84,7 @@ def get_suffix_lists(
             "urls": urls,
             "cache_fetch_timeout": cache_fetch_timeout,
             "fallback_to_snapshot": fallback_to_snapshot,
+            "session": session,
         },
         hashed_argnames=["urls", "fallback_to_snapshot"],
     )
@@ -85,10 +95,13 @@ def _get_suffix_lists(
     urls: Sequence[str],
     cache_fetch_timeout: float | int | None,
     fallback_to_snapshot: bool,
+    session: requests.Session = None,
 ) -> tuple[list[str], list[str]]:
     """Fetch, parse, and cache the suffix lists."""
     try:
-        text = find_first_response(cache, urls, cache_fetch_timeout=cache_fetch_timeout)
+        text = find_first_response(
+            cache, urls, cache_fetch_timeout=cache_fetch_timeout, session=session
+        )
     except SuffixListNotFound as exc:
         if fallback_to_snapshot:
             maybe_pkg_data = pkgutil.get_data("tldextract", ".tld_set_snapshot")
diff --git a/tldextract/tldextract.py b/tldextract/tldextract.py
index 95a7acd0..d66ea1ec 100644
--- a/tldextract/tldextract.py
+++ b/tldextract/tldextract.py
@@ -44,6 +44,7 @@
 from functools import wraps
 
 import idna
+import requests
 
 from .cache import DiskCache, get_cache_dir
 from .remote import lenient_netloc, looks_like_ip, looks_like_ipv6
@@ -172,6 +173,9 @@ def __init__(
         its mirror, but any similar document could be specified. Local files can
         be specified by using the `file://` protocol. (See `urllib2` documentation.)
 
+        If you need a proxy to access the URLs in `suffix_list_urls`, an optional
+        `session` with the proxy configured can be passed in.
+
         If there is no cached version loaded and no data is found from the `suffix_list_urls`,
         the module will fall back to the included TLD set snapshot. If you do not want
         this behavior, you may set `fallback_to_snapshot` to False, and an exception will be
@@ -221,13 +225,19 @@ def __init__(
         self._cache = DiskCache(cache_dir)
 
     def __call__(
-        self, url: str, include_psl_private_domains: bool | None = None
+        self,
+        url: str,
+        include_psl_private_domains: bool | None = None,
+        session: requests.Session = None,
     ) -> ExtractResult:
         """Alias for `extract_str`."""
-        return self.extract_str(url, include_psl_private_domains)
+        return self.extract_str(url, include_psl_private_domains, session=session)
 
     def extract_str(
-        self, url: str, include_psl_private_domains: bool | None = None
+        self,
+        url: str,
+        include_psl_private_domains: bool | None = None,
+        session: requests.Session = None,
     ) -> ExtractResult:
         """Take a string URL and splits it into its subdomain, domain, and suffix components.
@@ -239,12 +249,15 @@ def extract_str(
         >>> extractor.extract_str('http://forums.bbc.co.uk/')
         ExtractResult(subdomain='forums', domain='bbc', suffix='co.uk', is_private=False)
         """
-        return self._extract_netloc(lenient_netloc(url), include_psl_private_domains)
+        return self._extract_netloc(
+            lenient_netloc(url), include_psl_private_domains, session=session
+        )
 
     def extract_urllib(
         self,
         url: urllib.parse.ParseResult | urllib.parse.SplitResult,
         include_psl_private_domains: bool | None = None,
+        session: requests.Session = None,
     ) -> ExtractResult:
         """Take the output of urllib.parse URL parsing methods and further splits the parsed URL.
@@ -260,10 +273,15 @@ def extract_urllib(
         >>> extractor.extract_urllib(urllib.parse.urlsplit('http://forums.bbc.co.uk/'))
         ExtractResult(subdomain='forums', domain='bbc', suffix='co.uk', is_private=False)
         """
-        return self._extract_netloc(url.netloc, include_psl_private_domains)
+        return self._extract_netloc(
+            url.netloc, include_psl_private_domains, session=session
+        )
 
     def _extract_netloc(
-        self, netloc: str, include_psl_private_domains: bool | None
+        self,
+        netloc: str,
+        include_psl_private_domains: bool | None,
+        session: requests.Session = None,
     ) -> ExtractResult:
         netloc_with_ascii_dots = (
             netloc.replace("\u3002", "\u002e")
@@ -282,9 +300,9 @@ def _extract_netloc(
 
         labels = netloc_with_ascii_dots.split(".")
 
-        suffix_index, is_private = self._get_tld_extractor().suffix_index(
-            labels, include_psl_private_domains=include_psl_private_domains
-        )
+        suffix_index, is_private = self._get_tld_extractor(
+            session=session
+        ).suffix_index(labels, include_psl_private_domains=include_psl_private_domains)
 
         num_ipv4_labels = 4
         if suffix_index == len(labels) == num_ipv4_labels and looks_like_ip(
@@ -297,23 +315,25 @@ def _extract_netloc(
         domain = labels[suffix_index - 1] if suffix_index else ""
         return ExtractResult(subdomain, domain, suffix, is_private)
 
-    def update(self, fetch_now: bool = False) -> None:
+    def update(self, fetch_now: bool = False, session: requests.Session = None) -> None:
         """Force fetch the latest suffix list definitions."""
         self._extractor = None
         self._cache.clear()
         if fetch_now:
-            self._get_tld_extractor()
+            self._get_tld_extractor(session=session)
 
     @property
-    def tlds(self) -> list[str]:
+    def tlds(self, session: requests.Session = None) -> list[str]:
         """
         Returns the list of tld's used by default.
 
         This will vary based on `include_psl_private_domains` and `extra_suffixes`
         """
-        return list(self._get_tld_extractor().tlds())
+        return list(self._get_tld_extractor(session=session).tlds())
 
-    def _get_tld_extractor(self) -> _PublicSuffixListTLDExtractor:
+    def _get_tld_extractor(
+        self, session: requests.Session = None
+    ) -> _PublicSuffixListTLDExtractor:
         """Get or compute this object's TLDExtractor.
         Looks up the TLDExtractor in roughly the following order, based on the
@@ -332,6 +352,7 @@ def _get_tld_extractor(self) -> _PublicSuffixListTLDExtractor:
             urls=self.suffix_list_urls,
             cache_fetch_timeout=self.cache_fetch_timeout,
             fallback_to_snapshot=self.fallback_to_snapshot,
+            session=session,
         )
 
         if not any([public_tlds, private_tlds, self.extra_suffixes]):
@@ -400,9 +421,13 @@ def add_suffix(self, suffix: str, is_private: bool = False) -> None:
 
 @wraps(TLD_EXTRACTOR.__call__)
 def extract(  # noqa: D103
-    url: str, include_psl_private_domains: bool | None = False
+    url: str,
+    include_psl_private_domains: bool | None = False,
+    session: requests.Session = None,
 ) -> ExtractResult:
-    return TLD_EXTRACTOR(url, include_psl_private_domains=include_psl_private_domains)
+    return TLD_EXTRACTOR(
+        url, include_psl_private_domains=include_psl_private_domains, session=session
+    )
 
 
 @wraps(TLD_EXTRACTOR.update)

From 1a7b7d88f1e1a3346aea10dfc0cd9d348e6c4021 Mon Sep 17 00:00:00 2001
From: Chenglong Hao
Date: Wed, 25 Oct 2023 16:04:20 -0700
Subject: [PATCH 2/6] fix build issues

---
 tldextract/suffix_list.py |  7 +++----
 tldextract/tldextract.py  | 18 ++++++++++--------
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/tldextract/suffix_list.py b/tldextract/suffix_list.py
index cb6f4fc6..d00879c4 100644
--- a/tldextract/suffix_list.py
+++ b/tldextract/suffix_list.py
@@ -31,10 +31,9 @@ def find_first_response(
     cache: DiskCache,
     urls: Sequence[str],
     cache_fetch_timeout: float | int | None = None,
-    session: requests.Session = None,
+    session: requests.Session | None = None,
 ) -> str:
     """Decode the first successfully fetched URL, from UTF-8 encoding to Python unicode."""
-
     if session is None:
         session = requests.Session()
         session.mount("file://", FileAdapter())
@@ -73,7 +72,7 @@ def get_suffix_lists(
     urls: Sequence[str],
     cache_fetch_timeout: float | int | None,
     fallback_to_snapshot: bool,
-    session: requests.Session = None,
+    session: requests.Session | None = None,
 ) -> tuple[list[str], list[str]]:
     """Fetch, parse, and cache the suffix lists."""
     return cache.run_and_cache(
@@ -95,7 +94,7 @@ def _get_suffix_lists(
     urls: Sequence[str],
     cache_fetch_timeout: float | int | None,
     fallback_to_snapshot: bool,
-    session: requests.Session = None,
+    session: requests.Session | None = None,
 ) -> tuple[list[str], list[str]]:
     """Fetch, parse, and cache the suffix lists."""
     try:
diff --git a/tldextract/tldextract.py b/tldextract/tldextract.py
index d66ea1ec..b9a8dc98 100644
--- a/tldextract/tldextract.py
+++ b/tldextract/tldextract.py
@@ -228,7 +228,7 @@ def __call__(
         self,
         url: str,
         include_psl_private_domains: bool | None = None,
-        session: requests.Session = None,
+        session: requests.Session | None = None,
     ) -> ExtractResult:
         """Alias for `extract_str`."""
         return self.extract_str(url, include_psl_private_domains, session=session)
@@ -237,7 +237,7 @@ def extract_str(
         self,
         url: str,
         include_psl_private_domains: bool | None = None,
-        session: requests.Session = None,
+        session: requests.Session | None = None,
     ) -> ExtractResult:
         """Take a string URL and splits it into its subdomain, domain, and suffix components.
@@ -257,7 +257,7 @@ def extract_urllib(
         self,
         url: urllib.parse.ParseResult | urllib.parse.SplitResult,
         include_psl_private_domains: bool | None = None,
-        session: requests.Session = None,
+        session: requests.Session | None = None,
     ) -> ExtractResult:
         """Take the output of urllib.parse URL parsing methods and further splits the parsed URL.
@@ -281,7 +281,7 @@ def _extract_netloc(
         self,
         netloc: str,
         include_psl_private_domains: bool | None,
-        session: requests.Session = None,
+        session: requests.Session | None = None,
     ) -> ExtractResult:
         netloc_with_ascii_dots = (
             netloc.replace("\u3002", "\u002e")
@@ -315,7 +315,9 @@ def _extract_netloc(
         domain = labels[suffix_index - 1] if suffix_index else ""
         return ExtractResult(subdomain, domain, suffix, is_private)
 
-    def update(self, fetch_now: bool = False, session: requests.Session = None) -> None:
+    def update(
+        self, fetch_now: bool = False, session: requests.Session | None = None
+    ) -> None:
         """Force fetch the latest suffix list definitions."""
         self._extractor = None
         self._cache.clear()
@@ -323,7 +325,7 @@ def update(self, fetch_now: bool = False, session: requests.Session = None) -> None:
             self._get_tld_extractor(session=session)
 
     @property
-    def tlds(self, session: requests.Session = None) -> list[str]:
+    def tlds(self, session: requests.Session | None = None) -> list[str]:
         """
         Returns the list of tld's used by default.
 
@@ -332,7 +334,7 @@ def tlds(self, session: requests.Session = None) -> list[str]:
         return list(self._get_tld_extractor(session=session).tlds())
 
     def _get_tld_extractor(
-        self, session: requests.Session = None
+        self, session: requests.Session | None = None
     ) -> _PublicSuffixListTLDExtractor:
         """Get or compute this object's TLDExtractor.
@@ -423,7 +425,7 @@ def add_suffix(self, suffix: str, is_private: bool = False) -> None:
 def extract(  # noqa: D103
     url: str,
     include_psl_private_domains: bool | None = False,
-    session: requests.Session = None,
+    session: requests.Session | None = None,
 ) -> ExtractResult:
     return TLD_EXTRACTOR(
         url, include_psl_private_domains=include_psl_private_domains, session=session

From 5efec156b82c6ac65e69d0489d394664dabaa713 Mon Sep 17 00:00:00 2001
From: John Kurkowski
Date: Thu, 26 Oct 2023 11:37:57 -0700
Subject: [PATCH 3/6] Colocate docs

---
 tldextract/tldextract.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tldextract/tldextract.py b/tldextract/tldextract.py
index b9a8dc98..e791074f 100644
--- a/tldextract/tldextract.py
+++ b/tldextract/tldextract.py
@@ -173,9 +173,6 @@ def __init__(
         its mirror, but any similar document could be specified. Local files can
         be specified by using the `file://` protocol. (See `urllib2` documentation.)
 
-        If you need a proxy to access the URLs in `suffix_list_urls`, an optional
-        `session` with the proxy configured can be passed in.
-
         If there is no cached version loaded and no data is found from the `suffix_list_urls`,
         the module will fall back to the included TLD set snapshot. If you do not want
         this behavior, you may set `fallback_to_snapshot` to False, and an exception will be
@@ -248,6 +245,10 @@ def extract_str(
         ExtractResult(subdomain='forums.news', domain='cnn', suffix='com', is_private=False)
         >>> extractor.extract_str('http://forums.bbc.co.uk/')
         ExtractResult(subdomain='forums', domain='bbc', suffix='co.uk', is_private=False)
+
+        Allows configuring the HTTP request via the optional `session`
+        parameter. For example, if you need to use an HTTP proxy. See also
+        `requests.Session`.
""" return self._extract_netloc( lenient_netloc(url), include_psl_private_domains, session=session From 52514ee149be243a371f63092d3568c14d469654 Mon Sep 17 00:00:00 2001 From: Chenglong Hao Date: Thu, 26 Oct 2023 17:14:09 -0700 Subject: [PATCH 4/6] address comments --- tests/main_test.py | 22 +++++++++++++++++----- tldextract/suffix_list.py | 7 +++++-- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/tests/main_test.py b/tests/main_test.py index 64bccde8..782faca7 100644 --- a/tests/main_test.py +++ b/tests/main_test.py @@ -8,6 +8,7 @@ from collections.abc import Sequence from pathlib import Path from typing import Any +from unittest.mock import Mock import pytest import pytest_mock @@ -451,8 +452,8 @@ def test_cache_timeouts(tmp_path: Path) -> None: @responses.activate -def test_find_first_response(tmp_path: Path) -> None: - """Test it is able to find first response.""" +def test_find_first_response_without_session(tmp_path: Path) -> None: + """Test it is able to find first response without session passed in.""" server = "http://some-server.com" response_text = "server response" responses.add(responses.GET, server, status=200, body=response_text) @@ -462,10 +463,21 @@ def test_find_first_response(tmp_path: Path) -> None: result = tldextract.suffix_list.find_first_response(cache, [server], 5) assert result == response_text - # with session passed in - session = requests.Session() - result = tldextract.suffix_list.find_first_response(cache, [server], 5, session) + +def test_find_first_response_with_session(tmp_path: Path) -> None: + """Test it is able to find first response with passed in session.""" + server = "http://some-server.com" + response_text = "server response" + cache = DiskCache(str(tmp_path)) + mock_session = Mock() + mock_session.get.return_value.text = response_text + + result = tldextract.suffix_list.find_first_response( + cache, [server], 5, mock_session + ) assert result == response_text + mock_session.get.assert_called_once_with(server, timeout=5) + mock_session.close.assert_not_called() def test_include_psl_private_domain_attr() -> None: diff --git a/tldextract/suffix_list.py b/tldextract/suffix_list.py index d00879c4..192f6333 100644 --- a/tldextract/suffix_list.py +++ b/tldextract/suffix_list.py @@ -34,9 +34,11 @@ def find_first_response( session: requests.Session | None = None, ) -> str: """Decode the first successfully fetched URL, from UTF-8 encoding to Python unicode.""" + session_created = False if session is None: session = requests.Session() session.mount("file://", FileAdapter()) + session_created = True try: for url in urls: @@ -47,8 +49,9 @@ def find_first_response( except requests.exceptions.RequestException: LOG.exception("Exception reading Public Suffix List url %s", url) finally: - # Ensure the session is always closed - session.close() + # Ensure the session is always closed if it's constructed in the method + if session_created: + session.close() raise SuffixListNotFound( "No remote Public Suffix List found. 

From 80b14f8c8fa472770cab538d3c13c5055b860dc0 Mon Sep 17 00:00:00 2001
From: Chenglong Hao
Date: Thu, 26 Oct 2023 18:28:23 -0700
Subject: [PATCH 5/6] remove unused import

---
 tests/main_test.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/main_test.py b/tests/main_test.py
index 782faca7..996883ee 100644
--- a/tests/main_test.py
+++ b/tests/main_test.py
@@ -12,7 +12,6 @@
 
 import pytest
 import pytest_mock
-import requests
 import responses
 
 import tldextract

From a38dfb5d061eccabdc11d96ccf462335c992569b Mon Sep 17 00:00:00 2001
From: John Kurkowski
Date: Sat, 28 Oct 2023 02:07:05 -0700
Subject: [PATCH 6/6] fixup! Colocate docs

---
 tests/main_test.py       | 1 -
 tldextract/tldextract.py | 7 +++++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/tests/main_test.py b/tests/main_test.py
index 996883ee..c4050f7d 100644
--- a/tests/main_test.py
+++ b/tests/main_test.py
@@ -458,7 +458,6 @@ def test_find_first_response_without_session(tmp_path: Path) -> None:
     responses.add(responses.GET, server, status=200, body=response_text)
     cache = DiskCache(str(tmp_path))
 
-    # without session passed in
     result = tldextract.suffix_list.find_first_response(cache, [server], 5)
     assert result == response_text
 
diff --git a/tldextract/tldextract.py b/tldextract/tldextract.py
index e791074f..902cae69 100644
--- a/tldextract/tldextract.py
+++ b/tldextract/tldextract.py
@@ -249,6 +249,13 @@ def extract_str(
         Allows configuring the HTTP request via the optional `session`
         parameter. For example, if you need to use an HTTP proxy. See also
         `requests.Session`.
+
+        >>> import requests
+        >>> session = requests.Session()
+        >>> # customize your session here
+        >>> with session:
+        ...     extractor.extract_str("http://forums.news.cnn.com/", session=session)
+        ExtractResult(subdomain='forums.news', domain='cnn', suffix='com', is_private=False)
         """
         return self._extract_netloc(
             lenient_netloc(url), include_psl_private_domains, session=session
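
Taken together, the series threads an optional `requests.Session` from the public entry points (`extract`, `TLDExtract.__call__`, `extract_str`, `extract_urllib`, and `update`) down to the Public Suffix List fetch in `find_first_response`. A minimal usage sketch of the result, assuming this patched tldextract and the `requests-file` package it already depends on; the proxy URL is hypothetical:

    import requests
    from requests_file import FileAdapter

    import tldextract

    session = requests.Session()
    # Hypothetical proxy; session-level configuration is the point of this series.
    session.proxies.update({"https": "http://proxy.example.com:8080"})
    # The patched library mounts a file:// adapter only on sessions it creates
    # itself, so mount one here if any suffix_list_urls use file:// URLs.
    session.mount("file://", FileAdapter())

    with session:  # the caller owns this session; the library will not close it
        extractor = tldextract.TLDExtract()
        extractor.update(fetch_now=True, session=session)  # fetch the suffix list now
        print(extractor("http://forums.news.cnn.com/", session=session))
    # ExtractResult(subdomain='forums.news', domain='cnn', suffix='com', is_private=False)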
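
The session-lifecycle rule patch 4 settles on — create a session only when the caller did not supply one, and close only what you created — is worth stating outside the diff context. A standalone sketch of just that pattern (a hypothetical helper, not library code):

    from __future__ import annotations

    import requests


    def fetch_all(urls: list[str], session: requests.Session | None = None) -> list[str]:
        """Fetch each URL, creating and closing a session only if the caller didn't supply one."""
        session_created = False
        if session is None:
            session = requests.Session()
            session_created = True
        try:
            return [session.get(url, timeout=5).text for url in urls]
        finally:
            if session_created:
                # Never close a caller-owned session; the caller may reuse it.
                session.close()

This mirrors `find_first_response` after patch 4: the unconditional `session.close()` from patch 1 would have closed a caller's session after a single use, which the `session_created` flag and the `mock_session.close.assert_not_called()` test now guard against.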