diff --git a/news/12388.feature.rst b/news/12388.feature.rst
new file mode 100644
index 00000000000..27f368416d5
--- /dev/null
+++ b/news/12388.feature.rst
@@ -0,0 +1 @@
+Add parallel download support to BatchDownloader
diff --git a/src/pip/_internal/cli/cmdoptions.py b/src/pip/_internal/cli/cmdoptions.py
index 8fb16dc4a6a..873e6798c8f 100644
--- a/src/pip/_internal/cli/cmdoptions.py
+++ b/src/pip/_internal/cli/cmdoptions.py
@@ -764,6 +764,19 @@ def _handle_no_cache_dir(
     help="Check the build dependencies when PEP517 is used.",
 )
 
+parallel_downloads: Callable[..., Option] = partial(
+    Option,
+    "--parallel-downloads",
+    dest="parallel_downloads",
+    type="int",
+    metavar="n",
+    default=1,
+    help=(
+        "Use up to <n> threads to download packages in parallel."
+        " <n> must be greater than 0."
+    ),
+)
+
 
 def _handle_no_use_pep517(
     option: Option, opt: str, value: str, parser: OptionParser
diff --git a/src/pip/_internal/cli/req_command.py b/src/pip/_internal/cli/req_command.py
index 6f2f79c6b3f..1245e94d866 100644
--- a/src/pip/_internal/cli/req_command.py
+++ b/src/pip/_internal/cli/req_command.py
@@ -119,12 +119,17 @@ def _build_session(
         else:
             ssl_context = None
 
+        if "parallel_downloads" in options.__dict__:
+            parallel_downloads = options.parallel_downloads
+        else:
+            parallel_downloads = 1
         session = PipSession(
             cache=os.path.join(cache_dir, "http-v2") if cache_dir else None,
             retries=retries if retries is not None else options.retries,
             trusted_hosts=options.trusted_hosts,
             index_urls=self._get_index_urls(options),
             ssl_context=ssl_context,
+            parallel_downloads=parallel_downloads,
         )
 
         # Handle custom ca-bundles from the user
diff --git a/src/pip/_internal/commands/download.py b/src/pip/_internal/commands/download.py
index 54247a78a65..62961229b02 100644
--- a/src/pip/_internal/commands/download.py
+++ b/src/pip/_internal/commands/download.py
@@ -7,6 +7,7 @@
 from pip._internal.cli.cmdoptions import make_target_python
 from pip._internal.cli.req_command import RequirementCommand, with_cleanup
 from pip._internal.cli.status_codes import SUCCESS
+from pip._internal.exceptions import CommandError
 from pip._internal.operations.build.build_tracker import get_build_tracker
 from pip._internal.req.req_install import check_legacy_setup_py_options
 from pip._internal.utils.misc import ensure_dir, normalize_path, write_output
@@ -52,6 +53,7 @@ def add_options(self) -> None:
         self.cmd_opts.add_option(cmdoptions.no_use_pep517())
         self.cmd_opts.add_option(cmdoptions.check_build_deps())
         self.cmd_opts.add_option(cmdoptions.ignore_requires_python())
+        self.cmd_opts.add_option(cmdoptions.parallel_downloads())
 
         self.cmd_opts.add_option(
             "-d",
@@ -76,6 +78,9 @@ def add_options(self) -> None:
 
     @with_cleanup
     def run(self, options: Values, args: List[str]) -> int:
+        if options.parallel_downloads < 1:
+            raise CommandError("Value of '--parallel-downloads' must be greater than 0")
+
         options.ignore_installed = True
         # editable doesn't really make sense for `pip download`, but the bowels
         # of the RequirementSet code require that property.
diff --git a/src/pip/_internal/commands/install.py b/src/pip/_internal/commands/install.py
index 365764fc7cb..14d1ccc3fa6 100644
--- a/src/pip/_internal/commands/install.py
+++ b/src/pip/_internal/commands/install.py
@@ -74,6 +74,7 @@ def add_options(self) -> None:
         self.cmd_opts.add_option(cmdoptions.constraints())
         self.cmd_opts.add_option(cmdoptions.no_deps())
         self.cmd_opts.add_option(cmdoptions.pre())
+        self.cmd_opts.add_option(cmdoptions.parallel_downloads())
         self.cmd_opts.add_option(cmdoptions.editable())
 
         self.cmd_opts.add_option(
@@ -267,6 +268,8 @@ def run(self, options: Values, args: List[str]) -> int:
         if options.use_user_site and options.target_dir is not None:
             raise CommandError("Can not combine '--user' and '--target'")
+        if options.parallel_downloads < 1:
+            raise CommandError("Value of '--parallel-downloads' must be greater than 0")
 
         # Check whether the environment we're installing into is externally
         # managed, as specified in PEP 668. Specifying --root, --target, or
         # --prefix disables the check, since there's no reliable way to locate
diff --git a/src/pip/_internal/network/download.py b/src/pip/_internal/network/download.py
index 79b82a570e5..ec985f1fe44 100644
--- a/src/pip/_internal/network/download.py
+++ b/src/pip/_internal/network/download.py
@@ -4,6 +4,8 @@
 import logging
 import mimetypes
 import os
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from typing import Iterable, Optional, Tuple
 
 from pip._vendor.requests.models import CONTENT_CHUNK_SIZE, Response
@@ -119,6 +121,36 @@ def _http_get_download(session: PipSession, link: Link) -> Response:
     return resp
 
 
+def _download(
+    link: Link, location: str, session: PipSession, progress_bar: str
+) -> Tuple[str, str]:
+    """
+    Common download logic shared by the Downloader and BatchDownloader classes.
+
+    :param link: the Link object to be downloaded
+    :param location: the directory path to download to
+    :param session: the PipSession object to download with
+    :param progress_bar: renders a `rich` progress bar when set to "on"
+    :return: the path to the downloaded file and the content-type
+    """
+    try:
+        resp = _http_get_download(session, link)
+    except NetworkConnectionError as e:
+        assert e.response is not None
+        logger.critical("HTTP error %s while getting %s", e.response.status_code, link)
+        raise
+
+    filename = _get_http_response_filename(resp, link)
+    filepath = os.path.join(location, filename)
+
+    chunks = _prepare_download(resp, link, progress_bar)
+    with open(filepath, "wb") as content_file:
+        for chunk in chunks:
+            content_file.write(chunk)
+    content_type = resp.headers.get("Content-Type", "")
+    return filepath, content_type
+
+
 class Downloader:
     def __init__(
         self,
@@ -130,24 +162,7 @@ def __init__(
 
     def __call__(self, link: Link, location: str) -> Tuple[str, str]:
         """Download the file given by link into location."""
-        try:
-            resp = _http_get_download(self._session, link)
-        except NetworkConnectionError as e:
-            assert e.response is not None
-            logger.critical(
-                "HTTP error %s while getting %s", e.response.status_code, link
-            )
-            raise
-
-        filename = _get_http_response_filename(resp, link)
-        filepath = os.path.join(location, filename)
-
-        chunks = _prepare_download(resp, link, self._progress_bar)
-        with open(filepath, "wb") as content_file:
-            for chunk in chunks:
-                content_file.write(chunk)
-        content_type = resp.headers.get("Content-Type", "")
-        return filepath, content_type
+        return _download(link, location, self._session, self._progress_bar)
 
 
 class BatchDownloader:
@@ -159,28 +174,37 @@ def __init__(
         self._session = session
         self._progress_bar = progress_bar
 
+    def _sequential_download(
+        self, link: Link, location: str, progress_bar: str
+    ) -> Tuple[Link, Tuple[str, str]]:
+        filepath, content_type = _download(link, location, self._session, progress_bar)
+        return link, (filepath, content_type)
+
+    def _download_parallel(
+        self, links: Iterable[Link], location: str, max_workers: int
+    ) -> Iterable[Tuple[Link, Tuple[str, str]]]:
+        """
+        Wrap _sequential_download in a ThreadPoolExecutor. The `rich` progress
+        bar does not support naive parallelism, so it is disabled for parallel
+        downloads. For more info see PR #12388.
+        """
+        with ThreadPoolExecutor(max_workers=max_workers) as pool:
+            _download_parallel = partial(
+                self._sequential_download, location=location, progress_bar="off"
+            )
+            results = list(pool.map(_download_parallel, links))
+        return results
+
     def __call__(
         self, links: Iterable[Link], location: str
     ) -> Iterable[Tuple[Link, Tuple[str, str]]]:
         """Download the files given by links into location."""
-        for link in links:
-            try:
-                resp = _http_get_download(self._session, link)
-            except NetworkConnectionError as e:
-                assert e.response is not None
-                logger.critical(
-                    "HTTP error %s while getting %s",
-                    e.response.status_code,
-                    link,
-                )
-                raise
-
-            filename = _get_http_response_filename(resp, link)
-            filepath = os.path.join(location, filename)
-
-            chunks = _prepare_download(resp, link, self._progress_bar)
-            with open(filepath, "wb") as content_file:
-                for chunk in chunks:
-                    content_file.write(chunk)
-            content_type = resp.headers.get("Content-Type", "")
-            yield link, (filepath, content_type)
+        links = list(links)
+        max_workers = self._session.parallel_downloads
+        if max_workers == 1 or len(links) == 1:
+            for link in links:
+                yield self._sequential_download(link, location, self._progress_bar)
+        else:
+            results = self._download_parallel(links, location, max_workers)
+            for result in results:
+                yield result
diff --git a/src/pip/_internal/network/session.py b/src/pip/_internal/network/session.py
index 887dc14e796..821bf2f5e73 100644
--- a/src/pip/_internal/network/session.py
+++ b/src/pip/_internal/network/session.py
@@ -326,6 +326,7 @@ def __init__(
         trusted_hosts: Sequence[str] = (),
         index_urls: Optional[List[str]] = None,
         ssl_context: Optional["SSLContext"] = None,
+        parallel_downloads: int = 1,
         **kwargs: Any,
     ) -> None:
         """
@@ -362,12 +363,22 @@ def __init__(
             backoff_factor=0.25,
         )  # type: ignore
 
+        # Used to set the number of parallel downloads in
+        # pip._internal.network.download.BatchDownloader and to size the
+        # HTTPAdapter connection pools, so they do not hit the default limit
+        # (10) and emit "Connection pool is full" warnings.
+        self.parallel_downloads = parallel_downloads
+        pool_maxsize = max(self.parallel_downloads, 10)
         # Our Insecure HTTPAdapter disables HTTPS validation. It does not
         # support caching so we'll use it for all http:// URLs.
         # If caching is disabled, we will also use it for
         # https:// hosts that we've marked as ignoring
         # TLS errors for (trusted-hosts).
-        insecure_adapter = InsecureHTTPAdapter(max_retries=retries)
+        insecure_adapter = InsecureHTTPAdapter(
+            max_retries=retries,
+            pool_connections=pool_maxsize,
+            pool_maxsize=pool_maxsize,
+        )
 
         # We want to _only_ cache responses on securely fetched origins or when
         # the host is specified as trusted. We do this because
@@ -385,7 +396,12 @@
                 max_retries=retries,
             )
         else:
-            secure_adapter = HTTPAdapter(max_retries=retries, ssl_context=ssl_context)
+            secure_adapter = HTTPAdapter(
+                max_retries=retries,
+                ssl_context=ssl_context,
+                pool_connections=pool_maxsize,
+                pool_maxsize=pool_maxsize,
+            )
             self._trusted_host_adapter = insecure_adapter
 
         self.mount("https://", secure_adapter)
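
Usage sketch (illustrative only, not part of the diff): the snippet below exercises the new code path through pip's internal APIs as they appear above, namely the parallel_downloads keyword on PipSession and BatchDownloader's __call__. pip's _internal modules are not a stable public interface, and the URLs and download directory here are made-up placeholders.

import os

from pip._internal.models.link import Link
from pip._internal.network.download import BatchDownloader
from pip._internal.network.session import PipSession

# Session sized for four concurrent downloads; this also grows the adapter pools.
session = PipSession(parallel_downloads=4)
# The progress bar is forced to "off" internally when downloads run in parallel.
downloader = BatchDownloader(session, progress_bar="off")

# Placeholder links; in pip, Link objects normally come from the package finder.
links = [
    Link("https://example.invalid/packages/example-1.0-py3-none-any.whl"),
    Link("https://example.invalid/packages/other-2.0-py3-none-any.whl"),
]
location = "downloaded-wheels"
os.makedirs(location, exist_ok=True)

# With parallel_downloads > 1 and more than one link, BatchDownloader hands the
# work to a ThreadPoolExecutor; otherwise it falls back to sequential downloads.
for link, (filepath, content_type) in downloader(links, location):
    print(f"{link.url} -> {filepath} ({content_type})")

From the command line, the same behaviour is reached through the option added in this diff, e.g. "pip download --parallel-downloads 4 -d downloaded-wheels <requirements>" or "pip install --parallel-downloads 4 <requirements>"; values below 1 are rejected with a CommandError.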