From 4ba7a73a9eccb253d869bb53cfba00e5dcd51aea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gon=C3=A9ri=20Le=20Bouder?= Date: Mon, 12 Sep 2022 14:04:27 -0400 Subject: [PATCH] tests: add unit-tests for calculate_etag() Add test coverage for `calucalte_etag` that cover the two cases (with or without multipart). --- tests/unit/module_utils/test_s3.py | 42 ++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/unit/module_utils/test_s3.py b/tests/unit/module_utils/test_s3.py index b078f718db2..42c8ecfd024 100644 --- a/tests/unit/module_utils/test_s3.py +++ b/tests/unit/module_utils/test_s3.py @@ -5,10 +5,52 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import (absolute_import, division, print_function) + __metaclass__ = type from ansible_collections.amazon.aws.tests.unit.compat.mock import MagicMock from ansible_collections.amazon.aws.plugins.module_utils import s3 +from ansible.module_utils.basic import AnsibleModule + +import pytest + + +class FakeAnsibleModule(AnsibleModule): + def __init__(self): + pass + + +def test_calculate_etag_single_part(tmp_path_factory): + module = FakeAnsibleModule() + my_image = tmp_path_factory.mktemp("data") / "my.txt" + my_image.write_text("Hello World!") + + etag = s3.calculate_etag( + module, str(my_image), etag="", s3=None, bucket=None, obj=None + ) + assert etag == '"ed076287532e86365e841e92bfc50d8c"' + + +def test_calculate_etag_multi_part(tmp_path_factory): + module = FakeAnsibleModule() + my_image = tmp_path_factory.mktemp("data") / "my.txt" + my_image.write_text("Hello World!" * 1000) + + mocked_s3 = MagicMock() + mocked_s3.head_object.side_effect = [{"ContentLength": "1000"} for _i in range(12)] + + etag = s3.calculate_etag( + module, + str(my_image), + etag='"f20e84ac3d0c33cea77b3f29e3323a09-12"', + s3=mocked_s3, + bucket="my-bucket", + obj="my-obj", + ) + assert etag == '"f20e84ac3d0c33cea77b3f29e3323a09-12"' + mocked_s3.head_object.assert_called_with( + Bucket="my-bucket", Key="my-obj", PartNumber=12 + ) def test_validate_bucket_name():