diff --git a/src/scmrepo/git/lfs/client.py b/src/scmrepo/git/lfs/client.py index ea1554a4..05bed28b 100644 --- a/src/scmrepo/git/lfs/client.py +++ b/src/scmrepo/git/lfs/client.py @@ -3,7 +3,6 @@ import shutil from collections.abc import Iterable, Iterator from contextlib import AbstractContextManager, contextmanager, suppress -from multiprocessing import cpu_count from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING, Any, Optional @@ -32,7 +31,6 @@ class LFSClient(AbstractContextManager): JSON_CONTENT_TYPE = "application/vnd.git-lfs+json" - _JOBS = 4 * cpu_count() _REQUEST_TIMEOUT = 60 _SESSION_RETRIES = 5 _SESSION_BACKOFF_FACTOR = 0.1 @@ -154,6 +152,7 @@ async def _download( storage: "LFSStorage", objects: Iterable[Pointer], callback: "Callback" = DEFAULT_CALLBACK, + batch_size: Optional[int] = None, **kwargs, ): async def _get_one(from_path: str, to_path: str, **kwargs): @@ -179,7 +178,7 @@ async def _get_one(from_path: str, to_path: str, **kwargs): to_path = storage.oid_to_path(obj.oid) coros.append(_get_one(url, to_path, headers=headers)) for result in await _run_coros_in_chunks( - coros, batch_size=self._JOBS, return_exceptions=True + coros, batch_size=batch_size, return_exceptions=True ): if isinstance(result, BaseException): raise result diff --git a/src/scmrepo/git/lfs/smudge.py b/src/scmrepo/git/lfs/smudge.py index dcf7e0ed..b426e5be 100644 --- a/src/scmrepo/git/lfs/smudge.py +++ b/src/scmrepo/git/lfs/smudge.py @@ -11,7 +11,10 @@ def smudge( - storage: "LFSStorage", fobj: BinaryIO, url: Optional[str] = None + storage: "LFSStorage", + fobj: BinaryIO, + url: Optional[str] = None, + batch_size: Optional[int] = None, ) -> BinaryIO: """Wrap the specified binary IO stream and run LFS smudge if necessary.""" reader = io.BufferedReader(fobj) # type: ignore[arg-type] diff --git a/src/scmrepo/git/lfs/storage.py b/src/scmrepo/git/lfs/storage.py index ff472cc2..d4fe4e07 100644 --- a/src/scmrepo/git/lfs/storage.py +++ b/src/scmrepo/git/lfs/storage.py @@ -20,13 +20,14 @@ def fetch( url: str, objects: Collection[Pointer], progress: Optional[Callable[["GitProgressEvent"], None]] = None, + batch_size: Optional[int] = None, ): from .client import LFSClient with LFSCallback.as_lfs_callback(progress) as cb: cb.set_size(len(objects)) with LFSClient.from_git_url(url) as client: - client.download(self, objects, callback=cb) + client.download(self, objects, callback=cb, batch_size=batch_size) def oid_to_path(self, oid: str): return os.path.join(self.path, "objects", oid[0:2], oid[2:4], oid) @@ -40,6 +41,7 @@ def open( self, obj: Union[Pointer, str], fetch_url: Optional[str] = None, + batch_size: Optional[int] = None, **kwargs, ) -> BinaryIO: oid = obj if isinstance(obj, str) else obj.oid @@ -50,7 +52,7 @@ def open( if not fetch_url or not isinstance(obj, Pointer): raise try: - self.fetch(fetch_url, [obj]) + self.fetch(fetch_url, [obj], batch_size=batch_size) except BaseException as exc: # noqa: BLE001 raise FileNotFoundError( errno.ENOENT, os.strerror(errno.ENOENT), path