Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

config: Support ~/.aws/config parsing #5378

Merged
merged 3 commits into from
Feb 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dvc/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class RelPath(str):
"region": str,
"profile": str,
"credentialpath": str,
"configpath": str,
shcheklein marked this conversation as resolved.
Show resolved Hide resolved
"endpointurl": str,
"access_key_id": str,
"secret_access_key": str,
Expand Down
70 changes: 65 additions & 5 deletions dvc/tree/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
from dvc.path_info import CloudURLInfo
from dvc.progress import Tqdm
from dvc.scheme import Schemes
from dvc.utils import error_link
from dvc.utils import conversions, error_link

from .base import BaseTree

logger = logging.getLogger(__name__)

# Default location of the AWS shared config file (~/.aws/config); used by
# _process_config when the AWS_CONFIG_FILE environment variable is not set.
_AWS_CONFIG_PATH = os.path.join(os.path.expanduser("~"), ".aws", "config")


class S3Tree(BaseTree):
scheme = Schemes.S3
Expand Down Expand Up @@ -60,6 +62,56 @@ def __init__(self, repo, config):
if shared_creds:
os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds)

config_path = config.get("configpath")
if config_path:
os.environ.setdefault("AWS_CONFIG_FILE", config_path)
self._transfer_config = None

# Maps key names from the "s3" section of ~/.aws/config to the keyword
# argument names accepted by boto3.s3.transfer.TransferConfig.
# https://github.com/aws/aws-cli/blob/0376c6262d6b15dc36c82e6da6e1aad10249cc8c/awscli/customizations/s3/transferconfig.py#L107-L113
_TRANSFER_CONFIG_ALIASES = {
    "max_queue_size": "max_io_queue",
    "max_concurrent_requests": "max_concurrency",
    "multipart_threshold": "multipart_threshold",
    "multipart_chunksize": "multipart_chunksize",
}

def _transform_config(self, s3_config):
"""Splits the general s3 config into 2 different config
objects, one for transfer.TransferConfig and other is the
general session config"""

config, transfer_config = {}, {}
for key, value in s3_config.items():
if key in self._TRANSFER_CONFIG_ALIASES:
if key in {"multipart_chunksize", "multipart_threshold"}:
# cast human readable sizes (like 24MiB) to integers
value = conversions.human_readable_to_bytes(value)
else:
value = int(value)
transfer_config[self._TRANSFER_CONFIG_ALIASES[key]] = value
else:
config[key] = value

return config, transfer_config

def _process_config(self):
    """Read the AWS shared config file and extract per-profile s3 options.

    Returns the session-level s3 config dict, or ``None`` when the config
    file or the selected profile does not exist. As a side effect, builds
    ``self._transfer_config`` from the transfer-related keys.
    """
    from boto3.s3.transfer import TransferConfig
    from botocore.configloader import load_config

    # AWS_CONFIG_FILE overrides the default ~/.aws/config location
    path = os.environ.get("AWS_CONFIG_FILE", _AWS_CONFIG_PATH)
    if not os.path.exists(path):
        return None

    profiles = load_config(path)["profiles"]
    profile_config = profiles.get(self.profile or "default")
    if not profile_config:
        return None

    session_config, transfer_config = self._transform_config(
        profile_config.get("s3", {})
    )
    self._transfer_config = TransferConfig(**transfer_config)
    return session_config

@wrap_prop(threading.Lock())
@cached_property
def s3(self):
Expand All @@ -78,12 +130,15 @@ def s3(self):
session_opts["aws_session_token"] = self.session_token

session = boto3.session.Session(**session_opts)
s3_config = self._process_config()

return session.resource(
"s3",
endpoint_url=self.endpoint_url,
use_ssl=self.use_ssl,
config=boto3.session.Config(signature_version="s3v4"),
config=boto3.session.Config(
signature_version="s3v4", s3=s3_config
),
)

@contextmanager
Expand Down Expand Up @@ -355,7 +410,7 @@ def get_file_hash(self, path_info):

def _upload_fobj(self, fobj, to_info):
with self._get_obj(to_info) as obj:
obj.upload_fileobj(fobj)
obj.upload_fileobj(fobj, Config=self._transfer_config)

def _upload(
self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs
Expand All @@ -366,7 +421,10 @@ def _upload(
disable=no_progress_bar, total=total, bytes=True, desc=name
) as pbar:
obj.upload_file(
from_file, Callback=pbar.update, ExtraArgs=self.extra_args,
from_file,
Callback=pbar.update,
ExtraArgs=self.extra_args,
Config=self._transfer_config,
)

def _download(self, from_info, to_file, name=None, no_progress_bar=False):
Expand All @@ -377,4 +435,6 @@ def _download(self, from_info, to_file, name=None, no_progress_bar=False):
bytes=True,
desc=name,
) as pbar:
obj.download_file(to_file, Callback=pbar.update)
obj.download_file(
to_file, Callback=pbar.update, Config=self._transfer_config
)
24 changes: 24 additions & 0 deletions dvc/utils/conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# https://github.com/aws/aws-cli/blob/5aa599949f60b6af554fd5714d7161aa272716f7/awscli/customizations/s3/utils.py
Copy link
Contributor

@efiop efiop Feb 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this whole file is awscli-specific, we could move this stuff to dvc/tree/s3.py, or create something like dvc/tree/s3/utils.py and put it there. This piece is quite small, so we could go with the former, no problem.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think keeping here is OK (even though this is a pre-emptive assumption) I believe this utility might come in handy in other places since it does a very general job. The awscli-specific part here is the handling of IEC suffixes, though it doesn't differ much and something we can ignore.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, let's keep it then πŸ‘


# Size multipliers understood by ``human_readable_to_bytes``. Both the short
# (kb, mb, ...) and the IEC (kib, mib, ...) spellings map to powers of 1024,
# matching awscli's behaviour.
MULTIPLIERS = {
    "kb": 1024,
    "mb": 1024 ** 2,
    "gb": 1024 ** 3,
    "tb": 1024 ** 4,
    "kib": 1024,
    "mib": 1024 ** 2,
    "gib": 1024 ** 3,
    "tib": 1024 ** 4,
}


def human_readable_to_bytes(value):
    """Convert a human-readable size string (e.g. "24MiB", "10") to bytes.

    Suffix matching is case-insensitive; raises ``ValueError`` when the
    suffix is unrecognised or the numeric part is not an integer.
    """
    normalized = value.lower()
    for suffix, factor in MULTIPLIERS.items():
        if normalized.endswith(suffix):
            return int(normalized[: -len(suffix)]) * factor
    # No recognised suffix: the whole string must parse as a plain integer.
    return int(normalized)
82 changes: 82 additions & 0 deletions tests/func/test_s3.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import importlib
import sys
import textwrap
from functools import wraps

import boto3
Expand Down Expand Up @@ -130,3 +133,82 @@ def test_s3_upload_fobj(tmp_dir, dvc, s3):
tree.upload_fobj(stream, to_info, 1)

assert to_info.read_text() == "foo"


# Binary size units used by the transfer-config assertions below.
KB = 1024
MB = KB ** 2
GB = KB ** 3


def test_s3_aws_config(tmp_dir, dvc, s3, monkeypatch):
    # Write an AWS shared config under a fake home directory so S3Tree
    # discovers it at the default ~/.aws/config location.
    config_directory = tmp_dir / ".aws"
    config_directory.mkdir()
    (config_directory / "config").write_text(
        textwrap.dedent(
            """\
            [default]
            s3 =
              max_concurrent_requests = 20000
              max_queue_size = 1000
              multipart_threshold = 1000KiB
              multipart_chunksize = 64MB
              use_accelerate_endpoint = true
              addressing_style = path
            """
        )
    )

    # expanduser("~") resolves via USERPROFILE on Windows, HOME elsewhere
    if sys.platform == "win32":
        var = "USERPROFILE"
    else:
        var = "HOME"
    monkeypatch.setenv(var, str(tmp_dir))

    # Fresh import to see the effects of changing HOME variable
    # (the module computes the default config path at import time)
    s3_mod = importlib.reload(sys.modules[S3Tree.__module__])
    tree = s3_mod.S3Tree(dvc, s3.config)
    # _transfer_config is lazily built only once the s3 resource is created
    assert tree._transfer_config is None

    with tree._get_s3() as s3:
        # session-level options must reach the botocore client config
        s3_config = s3.meta.client.meta.config.s3
        assert s3_config["use_accelerate_endpoint"]
        assert s3_config["addressing_style"] == "path"

    # transfer-related options must be cast and mapped onto TransferConfig
    transfer_config = tree._transfer_config
    assert transfer_config.max_io_queue_size == 1000
    assert transfer_config.multipart_chunksize == 64 * MB
    assert transfer_config.multipart_threshold == 1000 * KB
    assert transfer_config.max_request_concurrency == 20000


def test_s3_aws_config_different_profile(tmp_dir, dvc, s3, monkeypatch):
    """Only the selected profile's s3 section should be loaded."""
    config_file = tmp_dir / "aws_config.ini"
    config_file.write_text(
        textwrap.dedent(
            """\
            [default]
            extra = keys
            s3 =
              addressing_style = auto
              use_accelerate_endpoint = true
              multipart_threshold = ThisIsNotGoingToBeCasted!
            [profile dev]
            some_extra = keys
            s3 =
              addresing_style = virtual
              multipart_threshold = 2GiB
            """
        )
    )
    # monkeypatch.setenv requires a string value; passing the path object
    # directly raises/warns on recent pytest versions.
    monkeypatch.setenv("AWS_CONFIG_FILE", str(config_file))

    tree = S3Tree(dvc, {**s3.config, "profile": "dev"})
    assert tree._transfer_config is None

    with tree._get_s3() as s3:
        s3_config = s3.meta.client.meta.config.s3
        # NOTE(review): "addresing_style" (sic) is misspelled in the fixture
        # above; config keys pass through unvalidated so the assert still
        # matches — confirm whether the typo is intentional.
        assert s3_config["addresing_style"] == "virtual"
        # default-profile options must not leak into the "dev" profile
        assert "use_accelerate_endpoint" not in s3_config

    transfer_config = tree._transfer_config
    assert transfer_config.multipart_threshold == 2 * GB
30 changes: 30 additions & 0 deletions tests/unit/utils/test_conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

from dvc.utils.conversions import human_readable_to_bytes

# Binary size units used by the parametrized cases below.
KB = 1024
MB = KB ** 2
GB = KB ** 3
TB = KB ** 4


@pytest.mark.parametrize(
    "value, expected",
    [
        ("10", 10),
        ("10 ", 10),
        ("1kb", 1 * KB),
        ("2kb", 2 * KB),
        ("1000mib", 1000 * MB),
        ("20gB", 20 * GB),
        ("10Tib", 10 * TB),
    ],
)
def test_conversions_human_readable_to_bytes(value, expected):
    # Plain integers, mixed-case suffixes and both kb/kib spellings convert.
    assert human_readable_to_bytes(value) == expected


@pytest.mark.parametrize("bad_value", ["foo", "10XB", "1000Pb", "fooMiB"])
def test_conversions_human_readable_to_bytes_invalid(bad_value):
    # Unknown suffixes and non-numeric prefixes must raise ValueError.
    with pytest.raises(ValueError):
        human_readable_to_bytes(bad_value)