Skip to content

Commit

Permalink
Allow passing in requests.Session (#311)
Browse files Browse the repository at this point in the history
Closes #158.

---------

Co-authored-by: Chenglong Hao <[email protected]>
Co-authored-by: John Kurkowski <[email protected]>
  • Loading branch information
3 people authored Oct 28, 2023
1 parent a540ca2 commit ea2a571
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 18 deletions.
29 changes: 29 additions & 0 deletions tests/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections.abc import Sequence
from pathlib import Path
from typing import Any
from unittest.mock import Mock

import pytest
import pytest_mock
Expand Down Expand Up @@ -449,6 +450,34 @@ def test_cache_timeouts(tmp_path: Path) -> None:
tldextract.suffix_list.find_first_response(cache, [server], 5)


@responses.activate
def test_find_first_response_without_session(tmp_path: Path) -> None:
"""Test it is able to find first response without session passed in."""
server = "http://some-server.com"
response_text = "server response"
responses.add(responses.GET, server, status=200, body=response_text)
cache = DiskCache(str(tmp_path))

result = tldextract.suffix_list.find_first_response(cache, [server], 5)
assert result == response_text


def test_find_first_response_with_session(tmp_path: Path) -> None:
"""Test it is able to find first response with passed in session."""
server = "http://some-server.com"
response_text = "server response"
cache = DiskCache(str(tmp_path))
mock_session = Mock()
mock_session.get.return_value.text = response_text

result = tldextract.suffix_list.find_first_response(
cache, [server], 5, mock_session
)
assert result == response_text
mock_session.get.assert_called_once_with(server, timeout=5)
mock_session.close.assert_not_called()


def test_include_psl_private_domain_attr() -> None:
"""Test private domains, which default to not being treated differently."""
extract_private = tldextract.TLDExtract(include_psl_private_domains=True)
Expand Down
19 changes: 17 additions & 2 deletions tldextract/suffix_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,28 @@ def find_first_response(
cache: DiskCache,
urls: Sequence[str],
cache_fetch_timeout: float | int | None = None,
session: requests.Session | None = None,
) -> str:
"""Decode the first successfully fetched URL, from UTF-8 encoding to Python unicode."""
with requests.Session() as session:
session_created = False
if session is None:
session = requests.Session()
session.mount("file://", FileAdapter())
session_created = True

try:
for url in urls:
try:
return cache.cached_fetch_url(
session=session, url=url, timeout=cache_fetch_timeout
)
except requests.exceptions.RequestException:
LOG.exception("Exception reading Public Suffix List url %s", url)
finally:
# Ensure the session is always closed if it's constructed in the method
if session_created:
session.close()

raise SuffixListNotFound(
"No remote Public Suffix List found. Consider using a mirror, or avoid this"
" fetch by constructing your TLDExtract with `suffix_list_urls=()`."
Expand All @@ -65,6 +75,7 @@ def get_suffix_lists(
urls: Sequence[str],
cache_fetch_timeout: float | int | None,
fallback_to_snapshot: bool,
session: requests.Session | None = None,
) -> tuple[list[str], list[str]]:
"""Fetch, parse, and cache the suffix lists."""
return cache.run_and_cache(
Expand All @@ -75,6 +86,7 @@ def get_suffix_lists(
"urls": urls,
"cache_fetch_timeout": cache_fetch_timeout,
"fallback_to_snapshot": fallback_to_snapshot,
"session": session,
},
hashed_argnames=["urls", "fallback_to_snapshot"],
)
Expand All @@ -85,10 +97,13 @@ def _get_suffix_lists(
urls: Sequence[str],
cache_fetch_timeout: float | int | None,
fallback_to_snapshot: bool,
session: requests.Session | None = None,
) -> tuple[list[str], list[str]]:
"""Fetch, parse, and cache the suffix lists."""
try:
text = find_first_response(cache, urls, cache_fetch_timeout=cache_fetch_timeout)
text = find_first_response(
cache, urls, cache_fetch_timeout=cache_fetch_timeout, session=session
)
except SuffixListNotFound as exc:
if fallback_to_snapshot:
maybe_pkg_data = pkgutil.get_data("tldextract", ".tld_set_snapshot")
Expand Down
67 changes: 51 additions & 16 deletions tldextract/tldextract.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from functools import wraps

import idna
import requests

from .cache import DiskCache, get_cache_dir
from .remote import lenient_netloc, looks_like_ip, looks_like_ipv6
Expand Down Expand Up @@ -221,13 +222,19 @@ def __init__(
self._cache = DiskCache(cache_dir)

def __call__(
self, url: str, include_psl_private_domains: bool | None = None
self,
url: str,
include_psl_private_domains: bool | None = None,
session: requests.Session | None = None,
) -> ExtractResult:
"""Alias for `extract_str`."""
return self.extract_str(url, include_psl_private_domains)
return self.extract_str(url, include_psl_private_domains, session=session)

def extract_str(
self, url: str, include_psl_private_domains: bool | None = None
self,
url: str,
include_psl_private_domains: bool | None = None,
session: requests.Session | None = None,
) -> ExtractResult:
"""Take a string URL and splits it into its subdomain, domain, and suffix components.
Expand All @@ -238,13 +245,27 @@ def extract_str(
ExtractResult(subdomain='forums.news', domain='cnn', suffix='com', is_private=False)
>>> extractor.extract_str('http://forums.bbc.co.uk/')
ExtractResult(subdomain='forums', domain='bbc', suffix='co.uk', is_private=False)
Allows configuring the HTTP request via the optional `session`
parameter. For example, if you need to use a HTTP proxy. See also
`requests.Session`.
>>> import requests
>>> session = requests.Session()
>>> # customize your session here
>>> with session:
... extractor.extract_str("http://forums.news.cnn.com/", session=session)
ExtractResult(subdomain='forums.news', domain='cnn', suffix='com', is_private=False)
"""
return self._extract_netloc(lenient_netloc(url), include_psl_private_domains)
return self._extract_netloc(
lenient_netloc(url), include_psl_private_domains, session=session
)

def extract_urllib(
self,
url: urllib.parse.ParseResult | urllib.parse.SplitResult,
include_psl_private_domains: bool | None = None,
session: requests.Session | None = None,
) -> ExtractResult:
"""Take the output of urllib.parse URL parsing methods and further splits the parsed URL.
Expand All @@ -260,10 +281,15 @@ def extract_urllib(
>>> extractor.extract_urllib(urllib.parse.urlsplit('http://forums.bbc.co.uk/'))
ExtractResult(subdomain='forums', domain='bbc', suffix='co.uk', is_private=False)
"""
return self._extract_netloc(url.netloc, include_psl_private_domains)
return self._extract_netloc(
url.netloc, include_psl_private_domains, session=session
)

def _extract_netloc(
self, netloc: str, include_psl_private_domains: bool | None
self,
netloc: str,
include_psl_private_domains: bool | None,
session: requests.Session | None = None,
) -> ExtractResult:
netloc_with_ascii_dots = (
netloc.replace("\u3002", "\u002e")
Expand All @@ -282,9 +308,9 @@ def _extract_netloc(

labels = netloc_with_ascii_dots.split(".")

suffix_index, is_private = self._get_tld_extractor().suffix_index(
labels, include_psl_private_domains=include_psl_private_domains
)
suffix_index, is_private = self._get_tld_extractor(
session=session
).suffix_index(labels, include_psl_private_domains=include_psl_private_domains)

num_ipv4_labels = 4
if suffix_index == len(labels) == num_ipv4_labels and looks_like_ip(
Expand All @@ -297,23 +323,27 @@ def _extract_netloc(
domain = labels[suffix_index - 1] if suffix_index else ""
return ExtractResult(subdomain, domain, suffix, is_private)

def update(self, fetch_now: bool = False) -> None:
def update(
self, fetch_now: bool = False, session: requests.Session | None = None
) -> None:
"""Force fetch the latest suffix list definitions."""
self._extractor = None
self._cache.clear()
if fetch_now:
self._get_tld_extractor()
self._get_tld_extractor(session=session)

@property
def tlds(self) -> list[str]:
def tlds(self, session: requests.Session | None = None) -> list[str]:
"""
Returns the list of tld's used by default.
This will vary based on `include_psl_private_domains` and `extra_suffixes`
"""
return list(self._get_tld_extractor().tlds())
return list(self._get_tld_extractor(session=session).tlds())

def _get_tld_extractor(self) -> _PublicSuffixListTLDExtractor:
def _get_tld_extractor(
self, session: requests.Session | None = None
) -> _PublicSuffixListTLDExtractor:
"""Get or compute this object's TLDExtractor.
Looks up the TLDExtractor in roughly the following order, based on the
Expand All @@ -332,6 +362,7 @@ def _get_tld_extractor(self) -> _PublicSuffixListTLDExtractor:
urls=self.suffix_list_urls,
cache_fetch_timeout=self.cache_fetch_timeout,
fallback_to_snapshot=self.fallback_to_snapshot,
session=session,
)

if not any([public_tlds, private_tlds, self.extra_suffixes]):
Expand Down Expand Up @@ -400,9 +431,13 @@ def add_suffix(self, suffix: str, is_private: bool = False) -> None:

@wraps(TLD_EXTRACTOR.__call__)
def extract( # noqa: D103
url: str, include_psl_private_domains: bool | None = False
url: str,
include_psl_private_domains: bool | None = False,
session: requests.Session | None = None,
) -> ExtractResult:
return TLD_EXTRACTOR(url, include_psl_private_domains=include_psl_private_domains)
return TLD_EXTRACTOR(
url, include_psl_private_domains=include_psl_private_domains, session=session
)


@wraps(TLD_EXTRACTOR.update)
Expand Down

0 comments on commit ea2a571

Please sign in to comment.