diff --git a/dvc/config_schema.py b/dvc/config_schema.py index a94b50f921..0537eb8aa6 100644 --- a/dvc/config_schema.py +++ b/dvc/config_schema.py @@ -138,6 +138,7 @@ class RelPath(str): "region": str, "profile": str, "credentialpath": str, + "configpath": str, "endpointurl": str, "access_key_id": str, "secret_access_key": str, diff --git a/dvc/tree/s3.py b/dvc/tree/s3.py index 568b7e83ce..e6537fb353 100644 --- a/dvc/tree/s3.py +++ b/dvc/tree/s3.py @@ -11,12 +11,14 @@ from dvc.path_info import CloudURLInfo from dvc.progress import Tqdm from dvc.scheme import Schemes -from dvc.utils import error_link +from dvc.utils import conversions, error_link from .base import BaseTree logger = logging.getLogger(__name__) +_AWS_CONFIG_PATH = os.path.join(os.path.expanduser("~"), ".aws", "config") + class S3Tree(BaseTree): scheme = Schemes.S3 @@ -61,6 +63,56 @@ def __init__(self, repo, config): if shared_creds: os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds) + config_path = config.get("configpath") + if config_path: + os.environ.setdefault("AWS_CONFIG_FILE", config_path) + self._transfer_config = None + + # https://github.com/aws/aws-cli/blob/0376c6262d6b15dc36c82e6da6e1aad10249cc8c/awscli/customizations/s3/transferconfig.py#L107-L113 + _TRANSFER_CONFIG_ALIASES = { + "max_queue_size": "max_io_queue", + "max_concurrent_requests": "max_concurrency", + "multipart_threshold": "multipart_threshold", + "multipart_chunksize": "multipart_chunksize", + } + + def _transform_config(self, s3_config): + """Splits the general s3 config into 2 different config + objects, one for transfer.TransferConfig and other is the + general session config""" + + config, transfer_config = {}, {} + for key, value in s3_config.items(): + if key in self._TRANSFER_CONFIG_ALIASES: + if key in {"multipart_chunksize", "multipart_threshold"}: + # cast human readable sizes (like 24MiB) to integers + value = conversions.human_readable_to_bytes(value) + else: + value = int(value) + transfer_config[self._TRANSFER_CONFIG_ALIASES[key]] = value + else: + config[key] = value + + return config, transfer_config + + def _process_config(self): + from boto3.s3.transfer import TransferConfig + from botocore.configloader import load_config + + config_path = os.environ.get("AWS_CONFIG_FILE", _AWS_CONFIG_PATH) + if not os.path.exists(config_path): + return None + + config = load_config(config_path) + profile = config["profiles"].get(self.profile or "default") + if not profile: + return None + + s3_config = profile.get("s3", {}) + s3_config, transfer_config = self._transform_config(s3_config) + self._transfer_config = TransferConfig(**transfer_config) + return s3_config + @wrap_prop(threading.Lock()) @cached_property def s3(self): @@ -79,12 +131,15 @@ def s3(self): session_opts["aws_session_token"] = self.session_token session = boto3.session.Session(**session_opts) + s3_config = self._process_config() return session.resource( "s3", endpoint_url=self.endpoint_url, use_ssl=self.use_ssl, - config=boto3.session.Config(signature_version="s3v4"), + config=boto3.session.Config( + signature_version="s3v4", s3=s3_config + ), ) @contextmanager @@ -373,7 +428,7 @@ def get_file_hash(self, path_info): def _upload_fobj(self, fobj, to_info): with self._get_obj(to_info) as obj: - obj.upload_fileobj(fobj) + obj.upload_fileobj(fobj, Config=self._transfer_config) def _upload( self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs @@ -384,7 +439,10 @@ def _upload( disable=no_progress_bar, total=total, bytes=True, desc=name ) as pbar: obj.upload_file( - from_file, Callback=pbar.update, ExtraArgs=self.extra_args, + from_file, + Callback=pbar.update, + ExtraArgs=self.extra_args, + Config=self._transfer_config, ) def _download(self, from_info, to_file, name=None, no_progress_bar=False): @@ -395,4 +453,6 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False): bytes=True, desc=name, ) as pbar: - obj.download_file(to_file, Callback=pbar.update) + obj.download_file( + to_file, Callback=pbar.update, Config=self._transfer_config + ) diff --git a/dvc/utils/conversions.py b/dvc/utils/conversions.py new file mode 100644 index 0000000000..5fe2950d6d --- /dev/null +++ b/dvc/utils/conversions.py @@ -0,0 +1,24 @@ +# https://github.com/aws/aws-cli/blob/5aa599949f60b6af554fd5714d7161aa272716f7/awscli/customizations/s3/utils.py + +MULTIPLIERS = { + "kb": 1024, + "mb": 1024 ** 2, + "gb": 1024 ** 3, + "tb": 1024 ** 4, + "kib": 1024, + "mib": 1024 ** 2, + "gib": 1024 ** 3, + "tib": 1024 ** 4, +} + + +def human_readable_to_bytes(value): + value = value.lower() + suffix = None + if value.endswith(tuple(MULTIPLIERS.keys())): + size = 2 + size += value[-2] == "i" # KiB, MiB etc + value, suffix = value[:-size], value[-size:] + + multiplier = MULTIPLIERS.get(suffix, 1) + return int(value) * multiplier diff --git a/tests/func/test_s3.py b/tests/func/test_s3.py index 8f6f4c680c..b148be84f5 100644 --- a/tests/func/test_s3.py +++ b/tests/func/test_s3.py @@ -1,3 +1,6 @@ +import importlib +import sys +import textwrap from functools import wraps import boto3 @@ -130,3 +133,82 @@ def test_s3_upload_fobj(tmp_dir, dvc, s3): tree.upload_fobj(stream, to_info, 1) assert to_info.read_text() == "foo" + + +KB = 1024 +MB = KB ** 2 +GB = KB ** 3 + + +def test_s3_aws_config(tmp_dir, dvc, s3, monkeypatch): + config_directory = tmp_dir / ".aws" + config_directory.mkdir() + (config_directory / "config").write_text( + textwrap.dedent( + """\ + [default] + s3 = + max_concurrent_requests = 20000 + max_queue_size = 1000 + multipart_threshold = 1000KiB + multipart_chunksize = 64MB + use_accelerate_endpoint = true + addressing_style = path + """ + ) + ) + + if sys.platform == "win32": + var = "USERPROFILE" + else: + var = "HOME" + monkeypatch.setenv(var, str(tmp_dir)) + + # Fresh import to see the effects of changing HOME variable + s3_mod = importlib.reload(sys.modules[S3Tree.__module__]) + tree = s3_mod.S3Tree(dvc, s3.config) + assert tree._transfer_config is None + + with tree._get_s3() as s3: + s3_config = s3.meta.client.meta.config.s3 + assert s3_config["use_accelerate_endpoint"] + assert s3_config["addressing_style"] == "path" + + transfer_config = tree._transfer_config + assert transfer_config.max_io_queue_size == 1000 + assert transfer_config.multipart_chunksize == 64 * MB + assert transfer_config.multipart_threshold == 1000 * KB + assert transfer_config.max_request_concurrency == 20000 + + +def test_s3_aws_config_different_profile(tmp_dir, dvc, s3, monkeypatch): + config_file = tmp_dir / "aws_config.ini" + config_file.write_text( + textwrap.dedent( + """\ + [default] + extra = keys + s3 = + addressing_style = auto + use_accelerate_endpoint = true + multipart_threshold = ThisIsNotGoingToBeCasted! + [profile dev] + some_extra = keys + s3 = + addresing_style = virtual + multipart_threshold = 2GiB + """ + ) + ) + monkeypatch.setenv("AWS_CONFIG_FILE", config_file) + + tree = S3Tree(dvc, {**s3.config, "profile": "dev"}) + assert tree._transfer_config is None + + with tree._get_s3() as s3: + s3_config = s3.meta.client.meta.config.s3 + assert s3_config["addresing_style"] == "virtual" + assert "use_accelerate_endpoint" not in s3_config + + transfer_config = tree._transfer_config + assert transfer_config.multipart_threshold == 2 * GB diff --git a/tests/unit/utils/test_conversions.py b/tests/unit/utils/test_conversions.py new file mode 100644 index 0000000000..063dd45565 --- /dev/null +++ b/tests/unit/utils/test_conversions.py @@ -0,0 +1,30 @@ +import pytest + +from dvc.utils.conversions import human_readable_to_bytes + +KB = 1024 +MB = KB ** 2 +GB = KB ** 3 +TB = KB ** 4 + + +@pytest.mark.parametrize( + "test_input, expected", + [ + ("10", 10), + ("10 ", 10), + ("1kb", 1 * KB), + ("2kb", 2 * KB), + ("1000mib", 1000 * MB), + ("20gB", 20 * GB), + ("10Tib", 10 * TB), + ], +) +def test_conversions_human_readable_to_bytes(test_input, expected): + assert human_readable_to_bytes(test_input) == expected + + +@pytest.mark.parametrize("invalid_input", ["foo", "10XB", "1000Pb", "fooMiB"]) +def test_conversions_human_readable_to_bytes_invalid(invalid_input): + with pytest.raises(ValueError): + human_readable_to_bytes(invalid_input)