diff --git a/tools/azure-sdk-tools/devtools_testutils/storage/__init__.py b/tools/azure-sdk-tools/devtools_testutils/storage/__init__.py
index 05543cff09c2..68cc432a41c5 100644
--- a/tools/azure-sdk-tools/devtools_testutils/storage/__init__.py
+++ b/tools/azure-sdk-tools/devtools_testutils/storage/__init__.py
@@ -1,6 +1,13 @@
 from .api_version_policy import ApiVersionAssertPolicy
 from .service_versions import service_version_map, ServiceVersion, is_version_before
-from .testcase import StorageTestCase, LogCaptured
+from .testcase import StorageTestCase, StorageRecordedTestCase, LogCaptured
 
-__all__ = ["ApiVersionAssertPolicy", "service_version_map", "StorageTestCase", "ServiceVersion", "is_version_before",
-           "LogCaptured"]
+__all__ = [
+    "ApiVersionAssertPolicy",
+    "service_version_map",
+    "StorageTestCase",
+    "StorageRecordedTestCase",
+    "ServiceVersion",
+    "is_version_before",
+    "LogCaptured"
+]
diff --git a/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py b/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py
index ee8d633673cf..fd10bf28ba22 100644
--- a/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py
+++ b/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py
@@ -1,3 +1,3 @@
-from .asynctestcase import AsyncStorageTestCase
+from .asynctestcase import AsyncStorageTestCase, AsyncStorageRecordedTestCase
 
-__all__ = ["AsyncStorageTestCase"]
+__all__ = ["AsyncStorageTestCase", "AsyncStorageRecordedTestCase"]
diff --git a/tools/azure-sdk-tools/devtools_testutils/storage/aio/asynctestcase.py b/tools/azure-sdk-tools/devtools_testutils/storage/aio/asynctestcase.py
index e24404bb4d40..7eef909dbc15 100644
--- a/tools/azure-sdk-tools/devtools_testutils/storage/aio/asynctestcase.py
+++ b/tools/azure-sdk-tools/devtools_testutils/storage/aio/asynctestcase.py
@@ -1,7 +1,7 @@
 import asyncio
 import functools
 
-from .. import StorageTestCase
+from .. import StorageTestCase, StorageRecordedTestCase
 from ...fake_credentials_async import AsyncFakeCredential
 
 from azure_devtools.scenario_tests.patches import mock_in_unit_test
@@ -67,3 +67,34 @@ def generate_oauth_token(self):
 
     def generate_fake_token(self):
         return AsyncFakeCredential()
+
+
+class AsyncStorageRecordedTestCase(StorageRecordedTestCase):
+
+    @staticmethod
+    def await_prepared_test(test_fn):
+        """Synchronous wrapper for async test methods. Used to avoid making changes
+        upstream to AbstractPreparer (which doesn't await the functions it wraps)
+        """
+
+        @functools.wraps(test_fn)
+        def run(test_class_instance, *args, **kwargs):
+            trim_kwargs_from_test_function(test_fn, kwargs)
+            loop = asyncio.get_event_loop()
+            return loop.run_until_complete(test_fn(test_class_instance, **kwargs))
+
+        return run
+
+    def generate_oauth_token(self):
+        if self.is_live:
+            from azure.identity.aio import ClientSecretCredential
+
+            return ClientSecretCredential(
+                self.get_settings_value("TENANT_ID"),
+                self.get_settings_value("CLIENT_ID"),
+                self.get_settings_value("CLIENT_SECRET"),
+            )
+        return self.generate_fake_token()
+
+    def generate_fake_token(self):
+        return AsyncFakeCredential()
diff --git a/tools/azure-sdk-tools/devtools_testutils/storage/testcase.py b/tools/azure-sdk-tools/devtools_testutils/storage/testcase.py
index 69d2b503ef43..55286039e6b8 100644
--- a/tools/azure-sdk-tools/devtools_testutils/storage/testcase.py
+++ b/tools/azure-sdk-tools/devtools_testutils/storage/testcase.py
@@ -7,6 +7,7 @@
 from __future__ import division
 
 from datetime import datetime, timedelta
+from io import StringIO
 import logging
 import math
 import os
@@ -15,17 +16,14 @@
 import time
 import zlib
 
-from devtools_testutils import AzureTestCase
+import pytest
+
+from devtools_testutils import AzureTestCase, AzureRecordedTestCase
 
 from .processors import XMSRequestIDBody
 from . import ApiVersionAssertPolicy, service_version_map
 from .. import FakeTokenCredential
 
-try:
-    from cStringIO import StringIO  # Python 2
-except ImportError:
-    from io import StringIO
-
 try:
     from azure.storage.blob import generate_account_sas, AccountSasPermissions, ResourceTypes
 except:
@@ -39,6 +37,19 @@
 ENABLE_LOGGING = True
 
 
+def generate_sas_token():
+    fake_key = "a" * 30 + "b" * 30
+
+    return "?" + generate_account_sas(
+        account_name="test",  # name of the storage account
+        account_key=fake_key,  # key for the storage account
+        resource_types=ResourceTypes(object=True),
+        permission=AccountSasPermissions(read=True, list=True),
+        start=datetime.now() - timedelta(hours=24),
+        expiry=datetime.now() + timedelta(days=8),
+    )
+
+
 class StorageTestCase(AzureTestCase):
     def __init__(self, *args, **kwargs):
         super(StorageTestCase, self).__init__(*args, **kwargs)
@@ -209,6 +220,159 @@ def create_storage_client_from_conn_str(self, client, *args, **kwargs):
         return client.from_connection_string(*args, **kwargs)
 
 
+class StorageRecordedTestCase(AzureRecordedTestCase):
+
+    def setup_class(cls):
+        cls.logger = logging.getLogger("azure.storage")
+        cls.sas_token = generate_sas_token()
+
+    def setup_method(self, _):
+        self.configure_logging()
+
+    def connection_string(self, account_name, key):
+        return (
+            "DefaultEndpointsProtocol=https;AcCounTName="
+            + account_name
+            + ";AccOuntKey="
+            + str(key)
+            + ";EndpoIntSuffix=core.windows.net"
+        )
+
+    def account_url(self, storage_account, storage_type):
+        """Return an url of storage account.
+
+        :param str storage_account: Storage account name
+        :param str storage_type: The Storage type part of the URL. Should be "blob", or "queue", etc.
+        """
+        protocol = os.environ.get("PROTOCOL", "https")
+        suffix = os.environ.get("ACCOUNT_URL_SUFFIX", "core.windows.net")
+        return f"{protocol}://{storage_account}.{storage_type}.{suffix}"
+
+    def configure_logging(self):
+        enable_logging = ENABLE_LOGGING
+
+        self.enable_logging() if enable_logging else self.disable_logging()
+
+    def enable_logging(self):
+        handler = logging.StreamHandler()
+        handler.setFormatter(logging.Formatter(LOGGING_FORMAT))
+        self.logger.handlers = [handler]
+        self.logger.setLevel(logging.DEBUG)
+        self.logger.propagate = True
+        self.logger.disabled = False
+
+    def disable_logging(self):
+        self.logger.propagate = False
+        self.logger.disabled = True
+        self.logger.handlers = []
+
+    def get_random_bytes(self, size):
+        # recordings don't like random stuff. making this more
+        # deterministic.
+        return b"a" * size
+
+    def get_random_text_data(self, size):
+        """Returns random unicode text data exceeding the size threshold for
+        chunking blob upload."""
+        checksum = zlib.adler32(self.qualified_test_name.encode()) & 0xFFFFFFFF
+        rand = random.Random(checksum)
+        text = u""
+        words = [u"hello", u"world", u"python", u"啊齄丂狛狜"]
+        while len(text) < size:
+            index = int(rand.random() * (len(words) - 1))
+            text = text + u" " + words[index]
+
+        return text
+
+    @staticmethod
+    def _set_test_proxy(service, settings):
+        if settings.USE_PROXY:
+            service.set_proxy(
+                settings.PROXY_HOST,
+                settings.PROXY_PORT,
+                settings.PROXY_USER,
+                settings.PROXY_PASSWORD,
+            )
+
+    def assertNamedItemInContainer(self, container, item_name, msg=None):
+        def _is_string(obj):
+            return isinstance(obj, str)
+
+        for item in container:
+            if _is_string(item):
+                if item == item_name:
+                    return
+            elif isinstance(item, dict):
+                if item_name == item["name"]:
+                    return
+            elif item.name == item_name:
+                return
+            elif hasattr(item, "snapshot") and item.snapshot == item_name:
+                return
+
+        error_message = f"{repr(item_name)} not found in {[str(c) for c in container]}"
+        pytest.fail(error_message)
+
+    def assertNamedItemNotInContainer(self, container, item_name, msg=None):
+        for item in container:
+            if item.name == item_name:
+                error_message = f"{repr(item_name)} unexpectedly found in {repr(container)}"
+                pytest.fail(error_message)
+
+    def assert_upload_progress(self, size, max_chunk_size, progress, unknown_size=False):
+        """Validates that the progress chunks align with our chunking procedure."""
+        total = None if unknown_size else size
+        small_chunk_size = size % max_chunk_size
+        assert len(progress) == math.ceil(size / max_chunk_size)
+        for i in progress:
+            assert i[0] % max_chunk_size == 0 or i[0] % max_chunk_size == small_chunk_size
+            assert i[1] == total
+
+    def assert_download_progress(self, size, max_chunk_size, max_get_size, progress):
+        """Validates that the progress chunks align with our chunking procedure."""
+        if size <= max_get_size:
+            assert len(progress) == 1
+            assert progress[0][0], size
+            assert progress[0][1], size
+        else:
+            small_chunk_size = (size - max_get_size) % max_chunk_size
+            assert len(progress) == 1 + math.ceil((size - max_get_size) / max_chunk_size)
+
+            assert progress[0][0], max_get_size
+            assert progress[0][1], size
+            for i in progress[1:]:
+                assert i[0] % max_chunk_size == 0 or i[0] % max_chunk_size == small_chunk_size
+                assert i[1] == size
+
+    def generate_oauth_token(self):
+        if self.is_live:
+            from azure.identity import ClientSecretCredential
+
+            return ClientSecretCredential(
+                self.get_settings_value("TENANT_ID"),
+                self.get_settings_value("CLIENT_ID"),
+                self.get_settings_value("CLIENT_SECRET"),
+            )
+        return self.generate_fake_token()
+
+    def generate_fake_token(self):
+        return FakeTokenCredential()
+
+    def _get_service_version(self, **kwargs):
+        env_version = service_version_map.get(os.environ.get("AZURE_LIVE_TEST_SERVICE_VERSION", "LATEST"))
+        return kwargs.pop("service_version", env_version)
+
+    def create_storage_client(self, client, *args, **kwargs):
+        kwargs["api_version"] = self._get_service_version(**kwargs)
+        kwargs["_additional_pipeline_policies"] = [ApiVersionAssertPolicy(kwargs["api_version"])]
+        return client(*args, **kwargs)
+
+    def create_storage_client_from_conn_str(self, client, *args, **kwargs):
+        kwargs["api_version"] = self._get_service_version(**kwargs)
+        kwargs["_additional_pipeline_policies"] = [ApiVersionAssertPolicy(kwargs["api_version"])]
+        return client.from_connection_string(*args, **kwargs)
+
+
 class LogCaptured(object):
     def __init__(self, test_case=None):
         # accept the test case so that we may reset logging after capturing logs