Skip to content

Commit

Permalink
config: Support ~/.aws/config parsing (#5378)
Browse files Browse the repository at this point in the history
* config: Support ~/.aws/config parsing

* skip when the ~/.aws/config doesn't exist

* Expand windows paths properly
  • Loading branch information
isidentical authored Feb 4, 2021
1 parent 21200c1 commit e7cc776
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 5 deletions.
1 change: 1 addition & 0 deletions dvc/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class RelPath(str):
"region": str,
"profile": str,
"credentialpath": str,
"configpath": str,
"endpointurl": str,
"access_key_id": str,
"secret_access_key": str,
Expand Down
70 changes: 65 additions & 5 deletions dvc/tree/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
from dvc.path_info import CloudURLInfo
from dvc.progress import Tqdm
from dvc.scheme import Schemes
from dvc.utils import error_link
from dvc.utils import conversions, error_link

from .base import BaseTree

logger = logging.getLogger(__name__)

_AWS_CONFIG_PATH = os.path.join(os.path.expanduser("~"), ".aws", "config")


class S3Tree(BaseTree):
scheme = Schemes.S3
Expand Down Expand Up @@ -61,6 +63,56 @@ def __init__(self, repo, config):
if shared_creds:
os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds)

config_path = config.get("configpath")
if config_path:
os.environ.setdefault("AWS_CONFIG_FILE", config_path)
self._transfer_config = None

# https://github.com/aws/aws-cli/blob/0376c6262d6b15dc36c82e6da6e1aad10249cc8c/awscli/customizations/s3/transferconfig.py#L107-L113
_TRANSFER_CONFIG_ALIASES = {
"max_queue_size": "max_io_queue",
"max_concurrent_requests": "max_concurrency",
"multipart_threshold": "multipart_threshold",
"multipart_chunksize": "multipart_chunksize",
}

def _transform_config(self, s3_config):
"""Splits the general s3 config into 2 different config
objects, one for transfer.TransferConfig and other is the
general session config"""

config, transfer_config = {}, {}
for key, value in s3_config.items():
if key in self._TRANSFER_CONFIG_ALIASES:
if key in {"multipart_chunksize", "multipart_threshold"}:
# cast human readable sizes (like 24MiB) to integers
value = conversions.human_readable_to_bytes(value)
else:
value = int(value)
transfer_config[self._TRANSFER_CONFIG_ALIASES[key]] = value
else:
config[key] = value

return config, transfer_config

def _process_config(self):
from boto3.s3.transfer import TransferConfig
from botocore.configloader import load_config

config_path = os.environ.get("AWS_CONFIG_FILE", _AWS_CONFIG_PATH)
if not os.path.exists(config_path):
return None

config = load_config(config_path)
profile = config["profiles"].get(self.profile or "default")
if not profile:
return None

s3_config = profile.get("s3", {})
s3_config, transfer_config = self._transform_config(s3_config)
self._transfer_config = TransferConfig(**transfer_config)
return s3_config

@wrap_prop(threading.Lock())
@cached_property
def s3(self):
Expand All @@ -79,12 +131,15 @@ def s3(self):
session_opts["aws_session_token"] = self.session_token

session = boto3.session.Session(**session_opts)
s3_config = self._process_config()

return session.resource(
"s3",
endpoint_url=self.endpoint_url,
use_ssl=self.use_ssl,
config=boto3.session.Config(signature_version="s3v4"),
config=boto3.session.Config(
signature_version="s3v4", s3=s3_config
),
)

@contextmanager
Expand Down Expand Up @@ -373,7 +428,7 @@ def get_file_hash(self, path_info):

def _upload_fobj(self, fobj, to_info):
with self._get_obj(to_info) as obj:
obj.upload_fileobj(fobj)
obj.upload_fileobj(fobj, Config=self._transfer_config)

def _upload(
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
Expand All @@ -384,7 +439,10 @@ def _upload(
disable=no_progress_bar, total=total, bytes=True, desc=name
) as pbar:
obj.upload_file(
from_file, Callback=pbar.update, ExtraArgs=self.extra_args,
from_file,
Callback=pbar.update,
ExtraArgs=self.extra_args,
Config=self._transfer_config,
)

def _download(self, from_info, to_file, name=None, no_progress_bar=False):
Expand All @@ -395,4 +453,6 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False):
bytes=True,
desc=name,
) as pbar:
obj.download_file(to_file, Callback=pbar.update)
obj.download_file(
to_file, Callback=pbar.update, Config=self._transfer_config
)
24 changes: 24 additions & 0 deletions dvc/utils/conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# https://github.com/aws/aws-cli/blob/5aa599949f60b6af554fd5714d7161aa272716f7/awscli/customizations/s3/utils.py

MULTIPLIERS = {
"kb": 1024,
"mb": 1024 ** 2,
"gb": 1024 ** 3,
"tb": 1024 ** 4,
"kib": 1024,
"mib": 1024 ** 2,
"gib": 1024 ** 3,
"tib": 1024 ** 4,
}


def human_readable_to_bytes(value):
value = value.lower()
suffix = None
if value.endswith(tuple(MULTIPLIERS.keys())):
size = 2
size += value[-2] == "i" # KiB, MiB etc
value, suffix = value[:-size], value[-size:]

multiplier = MULTIPLIERS.get(suffix, 1)
return int(value) * multiplier
82 changes: 82 additions & 0 deletions tests/func/test_s3.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import importlib
import sys
import textwrap
from functools import wraps

import boto3
Expand Down Expand Up @@ -130,3 +133,82 @@ def test_s3_upload_fobj(tmp_dir, dvc, s3):
tree.upload_fobj(stream, to_info, 1)

assert to_info.read_text() == "foo"


KB = 1024
MB = KB ** 2
GB = KB ** 3


def test_s3_aws_config(tmp_dir, dvc, s3, monkeypatch):
config_directory = tmp_dir / ".aws"
config_directory.mkdir()
(config_directory / "config").write_text(
textwrap.dedent(
"""\
[default]
s3 =
max_concurrent_requests = 20000
max_queue_size = 1000
multipart_threshold = 1000KiB
multipart_chunksize = 64MB
use_accelerate_endpoint = true
addressing_style = path
"""
)
)

if sys.platform == "win32":
var = "USERPROFILE"
else:
var = "HOME"
monkeypatch.setenv(var, str(tmp_dir))

# Fresh import to see the effects of changing HOME variable
s3_mod = importlib.reload(sys.modules[S3Tree.__module__])
tree = s3_mod.S3Tree(dvc, s3.config)
assert tree._transfer_config is None

with tree._get_s3() as s3:
s3_config = s3.meta.client.meta.config.s3
assert s3_config["use_accelerate_endpoint"]
assert s3_config["addressing_style"] == "path"

transfer_config = tree._transfer_config
assert transfer_config.max_io_queue_size == 1000
assert transfer_config.multipart_chunksize == 64 * MB
assert transfer_config.multipart_threshold == 1000 * KB
assert transfer_config.max_request_concurrency == 20000


def test_s3_aws_config_different_profile(tmp_dir, dvc, s3, monkeypatch):
config_file = tmp_dir / "aws_config.ini"
config_file.write_text(
textwrap.dedent(
"""\
[default]
extra = keys
s3 =
addressing_style = auto
use_accelerate_endpoint = true
multipart_threshold = ThisIsNotGoingToBeCasted!
[profile dev]
some_extra = keys
s3 =
addresing_style = virtual
multipart_threshold = 2GiB
"""
)
)
monkeypatch.setenv("AWS_CONFIG_FILE", config_file)

tree = S3Tree(dvc, {**s3.config, "profile": "dev"})
assert tree._transfer_config is None

with tree._get_s3() as s3:
s3_config = s3.meta.client.meta.config.s3
assert s3_config["addresing_style"] == "virtual"
assert "use_accelerate_endpoint" not in s3_config

transfer_config = tree._transfer_config
assert transfer_config.multipart_threshold == 2 * GB
30 changes: 30 additions & 0 deletions tests/unit/utils/test_conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

from dvc.utils.conversions import human_readable_to_bytes

KB = 1024
MB = KB ** 2
GB = KB ** 3
TB = KB ** 4


@pytest.mark.parametrize(
"test_input, expected",
[
("10", 10),
("10 ", 10),
("1kb", 1 * KB),
("2kb", 2 * KB),
("1000mib", 1000 * MB),
("20gB", 20 * GB),
("10Tib", 10 * TB),
],
)
def test_conversions_human_readable_to_bytes(test_input, expected):
assert human_readable_to_bytes(test_input) == expected


@pytest.mark.parametrize("invalid_input", ["foo", "10XB", "1000Pb", "fooMiB"])
def test_conversions_human_readable_to_bytes_invalid(invalid_input):
with pytest.raises(ValueError):
human_readable_to_bytes(invalid_input)

0 comments on commit e7cc776

Please sign in to comment.