From e3fbf6b16ca8d9992aa752b3d3d5babe041ad9ff Mon Sep 17 00:00:00 2001
From: Bikouo Aubin <79859644+abikouo@users.noreply.github.com>
Date: Wed, 19 Oct 2022 09:46:37 +0200
Subject: [PATCH] move get_s3_connection, reduce complexity and increase coverage (#1139)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

move get_s3_connection, reduce complexity and increase coverage

SUMMARY
get_s3_connection is duplicated in the s3_object and s3_object_info modules. This pull request moves the helper to a single place (module_utils.s3), reduces its complexity, and increases unit-test coverage.

ISSUE TYPE

Feature Pull Request

COMPONENT NAME
module_utils/s3
s3_bucket
s3_object
s3_object_info

ADDITIONAL INFORMATION
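A quick sketch of how a call site changes with the consolidated helper (the import path and signature below are the ones introduced by this patch; the surrounding variables are illustrative):

    from ansible_collections.amazon.aws.plugins.module_utils.s3 import get_s3_connection

    region, _ec2_url, aws_connect_kwargs = get_aws_connection_info(module, boto3=True)
    s3 = get_s3_connection(module, aws_connect_kwargs, region,
                           ceph=module.params.get("ceph"),
                           endpoint_url=module.params.get("endpoint_url"))

validate_bucket_name() also changes contract: it now returns an error message (or None) instead of calling fail_json() itself, so callers report the failure explicitly:

    err = validate_bucket_name(module.params["name"])
    if err:
        module.fail_json(msg=err)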
Reviewed-by: Mark Chappell
Reviewed-by: Gonéri Le Bouder
Reviewed-by: Bikouo Aubin
---
 .../module_utils_s3-unit-testing.yml          |   3 +
 plugins/module_utils/s3.py                    | 164 +++++--
 plugins/modules/s3_bucket.py                  |   4 +-
 plugins/modules/s3_object.py                  |  50 +-
 plugins/modules/s3_object_info.py             |  50 +-
 tests/unit/module_utils/test_s3.py            | 462 +++++++++++++++---
 tests/unit/plugins/modules/test_s3_object.py  |  29 --
 7 files changed, 529 insertions(+), 233 deletions(-)
 create mode 100644 changelogs/fragments/module_utils_s3-unit-testing.yml
 delete mode 100644 tests/unit/plugins/modules/test_s3_object.py

diff --git a/changelogs/fragments/module_utils_s3-unit-testing.yml b/changelogs/fragments/module_utils_s3-unit-testing.yml
new file mode 100644
index 00000000000..1a65c5e5011
--- /dev/null
+++ b/changelogs/fragments/module_utils_s3-unit-testing.yml
@@ -0,0 +1,3 @@
+---
+minor_changes:
+- module_utils.s3 - Refactor get_s3_connection into a module_utils for S3 modules and expand module_utils.s3 unit tests (https://github.com/ansible-collections/amazon.aws/pull/1139).
diff --git a/plugins/module_utils/s3.py b/plugins/module_utils/s3.py
index c13c91f25b1..21afee8d75c 100644
--- a/plugins/module_utils/s3.py
+++ b/plugins/module_utils/s3.py
@@ -4,7 +4,13 @@
 from __future__ import (absolute_import, division, print_function)
 __metaclass__ = type
 
+from ansible.module_utils.basic import to_text
+from urllib.parse import urlparse
+
+from .botocore import boto3_conn
+
 try:
+    from botocore.client import Config
     from botocore.exceptions import BotoCoreError, ClientError
 except ImportError:
     pass  # Handled by the calling module
@@ -22,6 +28,38 @@
 import string
 
 
+def s3_head_objects(client, parts, bucket, obj, versionId):
+    args = {"Bucket": bucket, "Key": obj}
+    if versionId:
+        args["VersionId"] = versionId
+
+    for part in range(1, parts + 1):
+        args["PartNumber"] = part
+        yield client.head_object(**args)
+
+
+def calculate_checksum_with_file(client, parts, bucket, obj, versionId, filename):
+    digests = []
+    with open(filename, 'rb') as f:
+        for head in s3_head_objects(client, parts, bucket, obj, versionId):
+            digests.append(md5(f.read(int(head['ContentLength']))).digest())
+
+    digest_squared = b''.join(digests)
+    return '"{0}-{1}"'.format(md5(digest_squared).hexdigest(), len(digests))
+
+
+def calculate_checksum_with_content(client, parts, bucket, obj, versionId, content):
+    digests = []
+    offset = 0
+    for head in s3_head_objects(client, parts, bucket, obj, versionId):
+        length = int(head['ContentLength'])
+        digests.append(md5(content[offset:offset + length]).digest())
+        offset += length
+
+    digest_squared = b''.join(digests)
+    return '"{0}-{1}"'.format(md5(digest_squared).hexdigest(), len(digests))
+
+
 def calculate_etag(module, filename, etag, s3, bucket, obj, version=None):
     if not HAS_MD5:
         return None
@@ -29,26 +67,10 @@ def calculate_etag(module, filename, etag, s3, bucket, obj, version=None):
     if '-' in etag:
         # Multi-part ETag; a hash of the hashes of each part.
         parts = int(etag[1:-1].split('-')[1])
-        digests = []
-
-        s3_kwargs = dict(
-            Bucket=bucket,
-            Key=obj,
-        )
-        if version:
-            s3_kwargs['VersionId'] = version
-
-        with open(filename, 'rb') as f:
-            for part_num in range(1, parts + 1):
-                s3_kwargs['PartNumber'] = part_num
-                try:
-                    head = s3.head_object(**s3_kwargs)
-                except (BotoCoreError, ClientError) as e:
-                    module.fail_json_aws(e, msg="Failed to get head object")
-                digests.append(md5(f.read(int(head['ContentLength']))))
-
-        digest_squared = md5(b''.join(m.digest() for m in digests))
-        return '"{0}-{1}"'.format(digest_squared.hexdigest(), len(digests))
+        try:
+            return calculate_checksum_with_file(s3, parts, bucket, obj, version, filename)
+        except (BotoCoreError, ClientError) as e:
+            module.fail_json_aws(e, msg="Failed to get head object")
     else:  # Compute the MD5 sum normally
         return '"{0}"'.format(module.md5(filename))
 
@@ -60,43 +82,89 @@ def calculate_etag_content(module, content, etag, s3, bucket, obj, version=None):
     if '-' in etag:
         # Multi-part ETag; a hash of the hashes of each part.
         parts = int(etag[1:-1].split('-')[1])
-        digests = []
-        offset = 0
-
-        s3_kwargs = dict(
-            Bucket=bucket,
-            Key=obj,
-        )
-        if version:
-            s3_kwargs['VersionId'] = version
-
-        for part_num in range(1, parts + 1):
-            s3_kwargs['PartNumber'] = part_num
-            try:
-                head = s3.head_object(**s3_kwargs)
-            except (BotoCoreError, ClientError) as e:
-                module.fail_json_aws(e, msg="Failed to get head object")
-            length = int(head['ContentLength'])
-            digests.append(md5(content[offset:offset + length]))
-            offset += length
-
-        digest_squared = md5(b''.join(m.digest() for m in digests))
-        return '"{0}-{1}"'.format(digest_squared.hexdigest(), len(digests))
+        try:
+            return calculate_checksum_with_content(s3, parts, bucket, obj, version, content)
+        except (BotoCoreError, ClientError) as e:
+            module.fail_json_aws(e, msg="Failed to get head object")
     else:  # Compute the MD5 sum normally
         return '"{0}"'.format(md5(content).hexdigest())
 
 
-def validate_bucket_name(module, name):
+def validate_bucket_name(name):
     # See: https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html
     if len(name) < 3:
-        module.fail_json(msg='the length of an S3 bucket must be at least 3 characters')
+        return 'the length of an S3 bucket must be at least 3 characters'
     if len(name) > 63:
-        module.fail_json(msg='the length of an S3 bucket cannot exceed 63 characters')
+        return 'the length of an S3 bucket cannot exceed 63 characters'
 
     legal_characters = string.ascii_lowercase + ".-" + string.digits
     illegal_characters = [c for c in name if c not in legal_characters]
     if illegal_characters:
-        module.fail_json(msg='invalid character(s) found in the bucket name')
+        return 'invalid character(s) found in the bucket name'
     if name[-1] not in string.ascii_lowercase + string.digits:
-        module.fail_json(msg='bucket names must begin and end with a letter or number')
-    return True
+        return 'bucket names must begin and end with a letter or number'
+    return None
+
+
+# Spot special case of fakes3.
+def is_fakes3(url):
+    """ Return True if endpoint_url has scheme fakes3:// """
+    result = False
+    if url is not None:
+        result = urlparse(url).scheme in ('fakes3', 'fakes3s')
+    return result
+
+
+def parse_fakes3_endpoint(url):
+    fakes3 = urlparse(url)
+    protocol = "http"
+    port = fakes3.port or 80
+    if fakes3.scheme == 'fakes3s':
+        protocol = "https"
+        port = fakes3.port or 443
+    endpoint_url = f"{protocol}://{fakes3.hostname}:{to_text(port)}"
+    use_ssl = bool(fakes3.scheme == 'fakes3s')
+    return {"endpoint": endpoint_url, "use_ssl": use_ssl}
+
+
+def parse_ceph_endpoint(url):
+    ceph = urlparse(url)
+    use_ssl = bool(ceph.scheme == 'https')
+    return {"endpoint": url, "use_ssl": use_ssl}
+
+
+def parse_default_endpoint(url, mode, encryption_mode, dualstack, sig_4):
+    result = {"endpoint": url}
+    config = {}
+    if (mode in ('get', 'getstr') and sig_4) or (mode == "put" and encryption_mode == "aws:kms"):
+        config["signature_version"] = "s3v4"
+    if dualstack:
+        config["s3"] = {"use_dualstack_endpoint": True}
+    if config != {}:
+        result["config"] = Config(**config)
+    return result
+
+
+def s3_conn_params(mode, encryption_mode, dualstack, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=False):
+    params = {"conn_type": "client", "resource": "s3", "region": location, **aws_connect_kwargs}
+    if ceph:
+        endpoint_p = parse_ceph_endpoint(endpoint_url)
+    elif is_fakes3(endpoint_url):
+        endpoint_p = parse_fakes3_endpoint(endpoint_url)
+    else:
+        endpoint_p = parse_default_endpoint(endpoint_url, mode, encryption_mode, dualstack, sig_4)
+
+    params.update(endpoint_p)
+    return params
+
+
+def get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=False):
+    s3_conn = s3_conn_params(module.params.get("mode"),
+                             module.params.get("encryption_mode"),
+                             module.params.get("dualstack"),
+                             aws_connect_kwargs,
+                             location,
+                             ceph,
+                             endpoint_url,
+                             sig_4)
+    return boto3_conn(module, **s3_conn)
diff --git a/plugins/modules/s3_bucket.py b/plugins/modules/s3_bucket.py
index 8a09858c39b..f4860900b5b 100644
--- a/plugins/modules/s3_bucket.py
+++ b/plugins/modules/s3_bucket.py
@@ -1132,7 +1132,9 @@ def main():
     region, _ec2_url, aws_connect_kwargs = get_aws_connection_info(module, boto3=True)
 
     if module.params.get('validate_bucket_name'):
-        validate_bucket_name(module, module.params["name"])
+        err = validate_bucket_name(module.params["name"])
+        if err:
+            module.fail_json(msg=err)
 
     if region in ('us-east-1', '', None):
         # default to US Standard region
diff --git a/plugins/modules/s3_object.py b/plugins/modules/s3_object.py
index 1bc379846ff..02ffaeac041 100644
--- a/plugins/modules/s3_object.py
+++ b/plugins/modules/s3_object.py
@@ -405,15 +405,13 @@
 except ImportError:
     pass  # Handled by AnsibleAWSModule
 
-from ansible.module_utils.basic import to_text
 from ansible.module_utils.basic import to_native
-from ansible.module_utils.six.moves.urllib.parse import urlparse
 
 from ansible_collections.amazon.aws.plugins.module_utils.core import AnsibleAWSModule
 from ansible_collections.amazon.aws.plugins.module_utils.core import is_boto3_error_code
 from ansible_collections.amazon.aws.plugins.module_utils.core import is_boto3_error_message
 from ansible_collections.amazon.aws.plugins.module_utils.ec2 import AWSRetry
-from ansible_collections.amazon.aws.plugins.module_utils.ec2 import boto3_conn
+from ansible_collections.amazon.aws.plugins.module_utils.s3 import get_s3_connection
 from ansible_collections.amazon.aws.plugins.module_utils.ec2 import get_aws_connection_info
 from ansible_collections.amazon.aws.plugins.module_utils.ec2 import ansible_dict_to_boto3_tag_list
 from ansible_collections.amazon.aws.plugins.module_utils.ec2 import boto3_tag_list_to_ansible_dict
@@ -845,48 +843,6 @@ def copy_object_to_bucket(module, s3, bucket, obj, encrypt, metadata, validate,
         module.fail_json_aws(e, msg="Failed while copying object %s from bucket %s." % (obj, module.params['copy_src'].get('Bucket')))
 
 
-def is_fakes3(endpoint_url):
-    """ Return True if endpoint_url has scheme fakes3:// """
-    if endpoint_url is not None:
-        return urlparse(endpoint_url).scheme in ('fakes3', 'fakes3s')
-    else:
-        return False
-
-
-def get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=False):
-    if ceph:  # TODO - test this
-        ceph = urlparse(endpoint_url)
-        params = dict(module=module, conn_type='client', resource='s3', use_ssl=ceph.scheme == 'https',
-                      region=location, endpoint=endpoint_url, **aws_connect_kwargs)
-    elif is_fakes3(endpoint_url):
-        fakes3 = urlparse(endpoint_url)
-        port = fakes3.port
-        if fakes3.scheme == 'fakes3s':
-            protocol = "https"
-            if port is None:
-                port = 443
-        else:
-            protocol = "http"
-            if port is None:
-                port = 80
-        params = dict(module=module, conn_type='client', resource='s3', region=location,
-                      endpoint="%s://%s:%s" % (protocol, fakes3.hostname, to_text(port)),
-                      use_ssl=fakes3.scheme == 'fakes3s', **aws_connect_kwargs)
-    else:
-        params = dict(module=module, conn_type='client', resource='s3', region=location, endpoint=endpoint_url, **aws_connect_kwargs)
-        if module.params['mode'] == 'put' and module.params['encryption_mode'] == 'aws:kms':
-            params['config'] = botocore.client.Config(signature_version='s3v4')
-        elif module.params['mode'] in ('get', 'getstr', 'geturl') and sig_4:
-            params['config'] = botocore.client.Config(signature_version='s3v4')
-        if module.params['dualstack']:
-            dualconf = botocore.client.Config(s3={'use_dualstack_endpoint': True})
-            if 'config' in params:
-                params['config'] = params['config'].merge(dualconf)
-            else:
-                params['config'] = dualconf
-    return boto3_conn(**params)
-
-
 def get_current_object_tags_dict(s3, bucket, obj, version=None):
     try:
         if version:
@@ -1040,7 +996,9 @@ def main():
     bucket_canned_acl = ["private", "public-read", "public-read-write", "authenticated-read"]
 
     if module.params.get('validate_bucket_name'):
-        validate_bucket_name(module, bucket)
+        err = validate_bucket_name(bucket)
+        if err:
+            module.fail_json(msg=err)
 
     if overwrite not in ['always', 'never', 'different', 'latest']:
         if module.boolean(overwrite):
diff --git a/plugins/modules/s3_object_info.py b/plugins/modules/s3_object_info.py
index 88e66dc4f05..5a98831fcac 100644
--- a/plugins/modules/s3_object_info.py
+++ b/plugins/modules/s3_object_info.py
@@ -440,16 +440,13 @@
 except ImportError:
     pass  # Handled by AnsibleAWSModule
 
-from ansible.module_utils.basic import to_text
-from ansible.module_utils.six.moves.urllib.parse import urlparse
-
 from ansible_collections.amazon.aws.plugins.module_utils.core import AnsibleAWSModule
 from ansible_collections.amazon.aws.plugins.module_utils.ec2 import AWSRetry
 from ansible_collections.amazon.aws.plugins.module_utils.ec2 import camel_dict_to_snake_dict
 from ansible_collections.amazon.aws.plugins.module_utils.ec2 import boto3_tag_list_to_ansible_dict
 from ansible_collections.amazon.aws.plugins.module_utils.core import is_boto3_error_code
 from ansible_collections.amazon.aws.plugins.module_utils.ec2 import get_aws_connection_info
-from ansible_collections.amazon.aws.plugins.module_utils.ec2 import boto3_conn
+from ansible_collections.amazon.aws.plugins.module_utils.s3 import get_s3_connection
 
 
 def describe_s3_object_acl(connection, bucket_name, object_name):
@@ -670,49 +667,6 @@ def object_check(connection, module, bucket_name, object_name):
         module.fail_json_aws(e, msg="The object %s does not exist or is missing access permissions." % object_name)
 
 
-# To get S3 connection, in case of dealing with ceph, dualstack, etc.
-def is_fakes3(endpoint_url):
-    """ Return True if endpoint_url has scheme fakes3:// """
-    if endpoint_url is not None:
-        return urlparse(endpoint_url).scheme in ('fakes3', 'fakes3s')
-    else:
-        return False
-
-
-def get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=False):
-    if ceph:  # TODO - test this
-        ceph = urlparse(endpoint_url)
-        params = dict(module=module, conn_type='client', resource='s3', use_ssl=ceph.scheme == 'https',
-                      region=location, endpoint=endpoint_url, **aws_connect_kwargs)
-    elif is_fakes3(endpoint_url):
-        fakes3 = urlparse(endpoint_url)
-        port = fakes3.port
-        if fakes3.scheme == 'fakes3s':
-            protocol = "https"
-            if port is None:
-                port = 443
-        else:
-            protocol = "http"
-            if port is None:
-                port = 80
-        params = dict(module=module, conn_type='client', resource='s3', region=location,
-                      endpoint="%s://%s:%s" % (protocol, fakes3.hostname, to_text(port)),
-                      use_ssl=fakes3.scheme == 'fakes3s', **aws_connect_kwargs)
-    else:
-        params = dict(module=module, conn_type='client', resource='s3', region=location, endpoint=endpoint_url, **aws_connect_kwargs)
-        if module.params['mode'] == 'put' and module.params['encryption_mode'] == 'aws:kms':
-            params['config'] = botocore.client.Config(signature_version='s3v4')
-        elif module.params['mode'] in ('get', 'getstr') and sig_4:
-            params['config'] = botocore.client.Config(signature_version='s3v4')
-        if module.params['dualstack']:
-            dualconf = botocore.client.Config(s3={'use_dualstack_endpoint': True})
-            if 'config' in params:
-                params['config'] = params['config'].merge(dualconf)
-            else:
-                params['config'] = dualconf
-    return boto3_conn(**params)
-
-
 def main():
 
     argument_spec = dict(
@@ -730,7 +684,7 @@ def main():
         ),
         bucket_name=dict(required=True, type='str'),
         object_name=dict(type='str'),
-        dualstack=dict(default='no', type='bool'),
+        dualstack=dict(default=False, type='bool'),
         ceph=dict(default=False, type='bool', aliases=['rgw']),
     )
diff --git a/tests/unit/module_utils/test_s3.py b/tests/unit/module_utils/test_s3.py
index cd73a6c2e43..8434fa932f4 100644
--- a/tests/unit/module_utils/test_s3.py
+++ b/tests/unit/module_utils/test_s3.py
@@ -5,80 +5,420 @@
 # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
 
 from __future__ import (absolute_import, division, print_function)
-
 __metaclass__ = type
 
-from ansible_collections.amazon.aws.tests.unit.compat.mock import MagicMock
+import pytest
+import random
+import string
+
 from ansible_collections.amazon.aws.plugins.module_utils import s3
-from ansible.module_utils.basic import AnsibleModule
+from unittest.mock import MagicMock, patch, call
+
+try:
+    import botocore
+except ImportError:
+    pass
+
+
+def generate_random_string(size, include_digits=True):
+    buffer = string.ascii_lowercase
+    if include_digits:
+        buffer += string.digits
+
+    return ''.join(random.choice(buffer) for i in range(size))
+
+
+@pytest.mark.parametrize("parts", range(0, 10, 3))
+@pytest.mark.parametrize("version", [True, False])
+def test_s3_head_objects(parts, version):
+
+    client = MagicMock()
+
+    s3bucket_name = "s3-bucket-%s" % (generate_random_string(8, False))
+    s3bucket_object = "s3-bucket-object-%s" % (generate_random_string(8, False))
+    versionId = None
+    if version:
+        versionId = random.randint(0, 1000)
+
+    total = 0
+    for head in s3.s3_head_objects(client, parts, s3bucket_name, s3bucket_object, versionId):
+        assert head == client.head_object.return_value
+        total += 1
+
+    assert total == parts
+    params = {"Bucket": s3bucket_name, "Key": s3bucket_object}
+    if versionId:
+        params["VersionId"] = versionId
+
+    api_calls = [call(PartNumber=i, **params) for i in range(1, parts + 1)]
+    client.head_object.assert_has_calls(api_calls, any_order=True)
+
+
+def raise_botoclient_exception():
+    params = {
+        'Error': {
+            'Code': 1,
+            'Message': 'Something went wrong'
+        },
+        'ResponseMetadata': {
+            'RequestId': '01234567-89ab-cdef-0123-456789abcdef'
+        }
+    }
+    return botocore.exceptions.ClientError(params, 'some_called_method')
+
+
+@pytest.mark.parametrize("use_file", [False, True])
+@pytest.mark.parametrize("parts", range(0, 10, 3))
+@patch("ansible_collections.amazon.aws.plugins.module_utils.s3.md5")
+@patch("ansible_collections.amazon.aws.plugins.module_utils.s3.s3_head_objects")
+def test_calculate_checksum(m_s3_head_objects, m_s3_md5, use_file, parts, tmp_path):
+
+    client = MagicMock()
+    mock_md5 = m_s3_md5.return_value
+
+    mock_md5.digest.return_value = b"1"
+    mock_md5.hexdigest.return_value = ''.join(["f" for i in range(32)])
+
+    m_s3_head_objects.return_value = [
+        {"ContentLength": "%d" % (i + 1)} for i in range(parts)
+    ]
 
-class FakeAnsibleModule(AnsibleModule):
-    def __init__(self):
-        pass
+    content = b'"f20e84ac3d0c33cea77b3f29e3323a09"'
+    test_function = s3.calculate_checksum_with_content
+    if use_file:
+        test_function = s3.calculate_checksum_with_file
+        test_dir = tmp_path / "test_s3"
+        test_dir.mkdir()
+        etag_file = test_dir / "etag.bin"
+        etag_file.write_bytes(content)
+        content = str(etag_file)
+
+    s3bucket_name = "s3-bucket-%s" % (generate_random_string(8, False))
+    s3bucket_object = "s3-bucket-object-%s" % (generate_random_string(8, False))
+    version = random.randint(0, 1000)
+
+    result = test_function(client, parts, s3bucket_name, s3bucket_object, version, content)
+
+    expected = '"{0}-{1}"'.format(mock_md5.hexdigest.return_value, parts)
+    assert result == expected
+
+    mock_md5.digest.assert_has_calls([call() for i in range(parts)])
+    mock_md5.hexdigest.assert_called_once()
+
+    m_s3_head_objects.assert_called_once_with(client, parts, s3bucket_name, s3bucket_object, version)
+
+
+@pytest.mark.parametrize("etag_multipart", [True, False])
+@patch("ansible_collections.amazon.aws.plugins.module_utils.s3.calculate_checksum_with_file")
+def test_calculate_etag(m_checksum_file, etag_multipart):
+
+    module = MagicMock()
+    client = MagicMock()
+
+    module.fail_json_aws.side_effect = SystemExit(2)
+    module.md5.return_value = generate_random_string(32)
+
+    s3bucket_name = "s3-bucket-%s" % (generate_random_string(8, False))
+    s3bucket_object = "s3-bucket-object-%s" % (generate_random_string(8, False))
+    version = random.randint(0, 1000)
+    parts = 3
+
+    etag = '"f20e84ac3d0c33cea77b3f29e3323a09"'
+    digest = '"9aa254f7f76fd14435b21e9448525b99"'
+
+    file_name = generate_random_string(32)
+
+    if not etag_multipart:
+        result = s3.calculate_etag(module, file_name, etag, client, s3bucket_name, s3bucket_object, version)
+        assert result == '"{0}"'.format(module.md5.return_value)
+        module.md5.assert_called_once_with(file_name)
+    else:
+        etag = '"f20e84ac3d0c33cea77b3f29e3323a09-{0}"'.format(parts)
+        m_checksum_file.return_value = digest
+        assert digest == s3.calculate_etag(module, file_name, etag, client, s3bucket_name, s3bucket_object, version)
+
+        m_checksum_file.assert_called_with(
+            client, parts, s3bucket_name, s3bucket_object, version, file_name
+        )
+
+
+@pytest.mark.parametrize("etag_multipart", [True, False])
+@patch("ansible_collections.amazon.aws.plugins.module_utils.s3.calculate_checksum_with_content")
+def test_calculate_etag_content(m_checksum_content, etag_multipart):
+
+    module = MagicMock()
+    client = MagicMock()
 
-def test_calculate_etag_single_part(tmp_path_factory):
-    module = FakeAnsibleModule()
-    my_image = tmp_path_factory.mktemp("data") / "my.txt"
-    my_image.write_text("Hello World!")
+    module.fail_json_aws.side_effect = SystemExit(2)
 
-    etag = s3.calculate_etag(
-        module, str(my_image), etag="", s3=None, bucket=None, obj=None
-    )
-    assert etag == '"ed076287532e86365e841e92bfc50d8c"'
+    s3bucket_name = "s3-bucket-%s" % (generate_random_string(8, False))
+    s3bucket_object = "s3-bucket-object-%s" % (generate_random_string(8, False))
+    version = random.randint(0, 1000)
+    parts = 3
 
+    etag = '"f20e84ac3d0c33cea77b3f29e3323a09"'
+    content = b'"f20e84ac3d0c33cea77b3f29e3323a09"'
+    digest = '"9aa254f7f76fd14435b21e9448525b99"'
 
-def test_calculate_etag_multi_part(tmp_path_factory):
-    module = FakeAnsibleModule()
-    my_image = tmp_path_factory.mktemp("data") / "my.txt"
-    my_image.write_text("Hello World!" * 1000)
+    if not etag_multipart:
+        assert digest == s3.calculate_etag_content(module, content, etag, client, s3bucket_name, s3bucket_object, version)
+    else:
+        etag = '"f20e84ac3d0c33cea77b3f29e3323a09-{0}"'.format(parts)
+        m_checksum_content.return_value = digest
+        result = s3.calculate_etag_content(module, content, etag, client, s3bucket_name, s3bucket_object, version)
+        assert result == digest
 
-    mocked_s3 = MagicMock()
-    mocked_s3.head_object.side_effect = [{"ContentLength": "1000"} for _i in range(12)]
+        m_checksum_content.assert_called_with(
+            client, parts, s3bucket_name, s3bucket_object, version, content
+        )
 
-    etag = s3.calculate_etag(
-        module,
-        str(my_image),
-        etag='"f20e84ac3d0c33cea77b3f29e3323a09-12"',
-        s3=mocked_s3,
-        bucket="my-bucket",
-        obj="my-obj",
-    )
-    assert etag == '"f20e84ac3d0c33cea77b3f29e3323a09-12"'
-    mocked_s3.head_object.assert_called_with(
-        Bucket="my-bucket", Key="my-obj", PartNumber=12
-    )
 
+@pytest.mark.parametrize("using_file", [True, False])
+@patch("ansible_collections.amazon.aws.plugins.module_utils.s3.calculate_checksum_with_content")
+@patch("ansible_collections.amazon.aws.plugins.module_utils.s3.calculate_checksum_with_file")
+def test_calculate_etag_failure(m_checksum_file, m_checksum_content, using_file):
 
-def test_validate_bucket_name():
     module = MagicMock()
+    client = MagicMock()
+
+    module.fail_json_aws.side_effect = SystemExit(2)
+
+    s3bucket_name = "s3-bucket-%s" % (generate_random_string(8, False))
+    s3bucket_object = "s3-bucket-object-%s" % (generate_random_string(8, False))
+    version = random.randint(0, 1000)
+    parts = 3
+
+    etag = '"f20e84ac3d0c33cea77b3f29e3323a09-{0}"'.format(parts)
+    content = "some content or file name"
+
+    if using_file:
+        test_method = s3.calculate_etag
+        m_checksum_file.side_effect = raise_botoclient_exception()
+    else:
+        test_method = s3.calculate_etag_content
+        m_checksum_content.side_effect = raise_botoclient_exception()
+
+    with pytest.raises(SystemExit):
+        test_method(module, content, etag, client, s3bucket_name, s3bucket_object, version)
+    module.fail_json_aws.assert_called()
+
+
+@pytest.mark.parametrize(
"bucket_name,result", + [ + ("docexamplebucket1", None), + ("log-delivery-march-2020", None), + ("my-hosted-content", None), + ("docexamplewebsite.com", None), + ("www.docexamplewebsite.com", None), + ("my.example.s3.bucket", None), + ("doc", None), + ("doc_example_bucket", "invalid character(s) found in the bucket name"), + ("DocExampleBucket", "invalid character(s) found in the bucket name"), + ("doc-example-bucket-", "bucket names must begin and end with a letter or number"), + ( + "this.string.has.more.than.63.characters.so.it.should.not.passed.the.validated", + "the length of an S3 bucket cannot exceed 63 characters" + ), + ("my", "the length of an S3 bucket must be at least 3 characters") + ] +) +def test_validate_bucket_name(bucket_name, result): + + assert result == s3.validate_bucket_name(bucket_name) + + +mod_urlparse = "ansible_collections.amazon.aws.plugins.module_utils.s3.urlparse" + + +class UrlInfo(object): + + def __init__(self, scheme=None, hostname=None, port=None): + self.hostname = hostname + self.scheme = scheme + self.port = port + + +@patch(mod_urlparse) +def test_is_fakes3_with_none_arg(m_urlparse): + m_urlparse.side_effect = SystemExit(1) + result = s3.is_fakes3(None) + assert not result + m_urlparse.assert_not_called() + + +@pytest.mark.parametrize( + "url,scheme,result", + [ + ("https://test-s3.amazon.com", "https", False), + ("fakes3://test-s3.amazon.com", "fakes3", True), + ("fakes3s://test-s3.amazon.com", "fakes3s", True), + ] +) +@patch(mod_urlparse) +def test_is_fakes3(m_urlparse, url, scheme, result): + m_urlparse.return_value = UrlInfo(scheme=scheme) + assert result == s3.is_fakes3(url) + m_urlparse.assert_called_with(url) + + +@pytest.mark.parametrize( + "url,urlinfo,endpoint", + [ + ( + "fakes3://test-s3.amazon.com", + { + "scheme": "fakes3", + "hostname": "test-s3.amazon.com" + }, + { + "endpoint": "http://test-s3.amazon.com:80", + "use_ssl": False + } + ), + ( + "fakes3://test-s3.amazon.com:8080", + { + "scheme": "fakes3", + "hostname": "test-s3.amazon.com", + "port": 8080 + }, + { + "endpoint": "http://test-s3.amazon.com:8080", + "use_ssl": False + } + ), + ( + "fakes3s://test-s3.amazon.com", + { + "scheme": "fakes3s", + "hostname": "test-s3.amazon.com" + }, + { + "endpoint": "https://test-s3.amazon.com:443", + "use_ssl": True + } + ), + ( + "fakes3s://test-s3.amazon.com:9096", + { + "scheme": "fakes3s", + "hostname": "test-s3.amazon.com", + "port": 9096 + }, + { + "endpoint": "https://test-s3.amazon.com:9096", + "use_ssl": True + } + ) + ] +) +@patch(mod_urlparse) +def test_parse_fakes3_endpoint(m_urlparse, url, urlinfo, endpoint): + m_urlparse.return_value = UrlInfo(**urlinfo) + result = s3.parse_fakes3_endpoint(url) + assert endpoint == result + m_urlparse.assert_called_with(url) + + +@pytest.mark.parametrize( + "url,scheme,use_ssl", + [ + ("https://test-s3-ceph.amazon.com", "https", True), + ("http://test-s3-ceph.amazon.com", "http", False), + ] +) +@patch(mod_urlparse) +def test_parse_ceph_endpoint(m_urlparse, url, scheme, use_ssl): + m_urlparse.return_value = UrlInfo(scheme=scheme) + result = s3.parse_ceph_endpoint(url) + assert result == {"endpoint": url, "use_ssl": use_ssl} + m_urlparse.assert_called_with(url) + + +@pytest.mark.parametrize("mode", ["put", "get", "getstr"]) +@pytest.mark.parametrize("encryption_mode", ["aws:kms", "aws:unknown"]) +@pytest.mark.parametrize("dualstack", [True, False]) +@pytest.mark.parametrize("sig_4", [True, False]) +@patch('ansible_collections.amazon.aws.plugins.module_utils.s3.Config') +def 
test_parse_default_endpoint(m_config, mode, encryption_mode, dualstack, sig_4): + + url = "https://my-bucket.s3.us-west-2.amazonaws.com" + + signature_version = False + if mode in ('get', 'getstr') and sig_4: + signature_version = True + if mode == "put" and encryption_mode == "aws:kms": + signature_version = True + + attributes = {} + if signature_version: + attributes["signature_version"] = "s3v4" + if dualstack: + attributes["s3"] = {"use_dualstack_endpoint": True} + + result = s3.parse_default_endpoint(url, mode, encryption_mode, dualstack, sig_4) + + expected = {"endpoint": url} + if attributes: + m_config.assert_called_with(**attributes) + expected["config"] = m_config.return_value + else: + m_config.assert_not_called() + + assert result == expected + + +@pytest.mark.parametrize("scenario", ["ceph", "fakes3", "default"]) +@patch('ansible_collections.amazon.aws.plugins.module_utils.s3.parse_default_endpoint') +@patch('ansible_collections.amazon.aws.plugins.module_utils.s3.parse_fakes3_endpoint') +@patch('ansible_collections.amazon.aws.plugins.module_utils.s3.is_fakes3') +@patch('ansible_collections.amazon.aws.plugins.module_utils.s3.parse_ceph_endpoint') +def test_s3_conn_params(m_parse_ceph_endpoint, + m_is_fakes3, + m_parse_fakes3_endpoint, + m_parse_default_endpoint, + scenario): + + url = "https://my-bucket.s3.us-west-2.amazonaws.com" + region = "us-east-1" + aws_connect_kwargs = {"aws_secret_key": "secret123!", "aws_access_key": "ABCDEFG"} + mode = "put" + encryption_mode = "aws:test" + dualstack = False + sig_4 = False + + endpoint = {"endpoint": url, "config": {"s3": True, "signature": "s123"}} + + ceph = bool(scenario == "ceph") + isfakes3 = bool(scenario == "fakes3") + + m_is_fakes3.return_value = isfakes3 + if ceph: + m_parse_ceph_endpoint.return_value = endpoint + elif isfakes3: + m_parse_fakes3_endpoint.return_value = endpoint + else: + m_parse_default_endpoint.return_value = endpoint + + expected = {"conn_type": "client", "resource": "s3", "region": region} + expected.update(aws_connect_kwargs) + expected.update(endpoint) + + assert expected == s3.s3_conn_params(mode, encryption_mode, dualstack, aws_connect_kwargs, region, ceph, url, sig_4) - assert s3.validate_bucket_name(module, "docexamplebucket1") is True - assert not module.fail_json.called - assert s3.validate_bucket_name(module, "log-delivery-march-2020") is True - assert not module.fail_json.called - assert s3.validate_bucket_name(module, "my-hosted-content") is True - assert not module.fail_json.called - - assert s3.validate_bucket_name(module, "docexamplewebsite.com") is True - assert not module.fail_json.called - assert s3.validate_bucket_name(module, "www.docexamplewebsite.com") is True - assert not module.fail_json.called - assert s3.validate_bucket_name(module, "my.example.s3.bucket") is True - assert not module.fail_json.called - assert s3.validate_bucket_name(module, "doc") is True - assert not module.fail_json.called - - module.fail_json.reset_mock() - s3.validate_bucket_name(module, "doc_example_bucket") - assert module.fail_json.called - - module.fail_json.reset_mock() - s3.validate_bucket_name(module, "DocExampleBucket") - assert module.fail_json.called - module.fail_json.reset_mock() - s3.validate_bucket_name(module, "doc-example-bucket-") - assert module.fail_json.called - s3.validate_bucket_name(module, "my") - assert module.fail_json.called + if ceph: + m_parse_ceph_endpoint.assert_called_with(url) + m_parse_fakes3_endpoint.assert_not_called() + m_parse_default_endpoint.assert_not_called() + elif 
isfakes3: + m_parse_fakes3_endpoint.assert_called_with(url) + m_parse_ceph_endpoint.assert_not_called() + m_parse_default_endpoint.assert_not_called() + else: + m_parse_default_endpoint.assert_called_with( + url, mode, encryption_mode, dualstack, sig_4 + ) + m_parse_ceph_endpoint.assert_not_called() + m_parse_fakes3_endpoint.assert_not_called() diff --git a/tests/unit/plugins/modules/test_s3_object.py b/tests/unit/plugins/modules/test_s3_object.py deleted file mode 100644 index b0251307229..00000000000 --- a/tests/unit/plugins/modules/test_s3_object.py +++ /dev/null @@ -1,29 +0,0 @@ -# Make coding more python3-ish -from __future__ import (absolute_import, division, print_function) -__metaclass__ = type - -from ansible.module_utils.six.moves.urllib.parse import urlparse - -from ansible_collections.amazon.aws.plugins.modules import s3_object - - -class TestUrlparse(): - - def test_urlparse(self): - actual = urlparse("http://test.com/here") - assert actual.scheme == "http" - assert actual.netloc == "test.com" - assert actual.path == "/here" - - def test_is_fakes3(self): - actual = s3_object.is_fakes3("fakes3://bla.blubb") - assert actual is True - - def test_get_s3_connection(self): - aws_connect_kwargs = dict(aws_access_key_id="access_key", - aws_secret_access_key="secret_key") - location = None - rgw = True - s3_url = "http://bla.blubb" - actual = s3_object.get_s3_connection(None, aws_connect_kwargs, location, rgw, s3_url) - assert "bla.blubb" in str(actual._endpoint)