Skip to content

Commit

Permalink
[Storage] Add proxy-supporting test classes (#24937)
Browse files Browse the repository at this point in the history
  • Loading branch information
mccoyp authored Jun 23, 2022
1 parent a7a18dd commit d653189
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 12 deletions.
13 changes: 10 additions & 3 deletions tools/azure-sdk-tools/devtools_testutils/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from .api_version_policy import ApiVersionAssertPolicy
from .service_versions import service_version_map, ServiceVersion, is_version_before
from .testcase import StorageTestCase, LogCaptured
from .testcase import StorageTestCase, StorageRecordedTestCase, LogCaptured

__all__ = ["ApiVersionAssertPolicy", "service_version_map", "StorageTestCase", "ServiceVersion", "is_version_before",
"LogCaptured"]
__all__ = [
"ApiVersionAssertPolicy",
"service_version_map",
"StorageTestCase",
"StorageRecordedTestCase",
"ServiceVersion",
"is_version_before",
"LogCaptured"
]
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .asynctestcase import AsyncStorageTestCase
from .asynctestcase import AsyncStorageTestCase, AsyncStorageRecordedTestCase

__all__ = ["AsyncStorageTestCase"]
__all__ = ["AsyncStorageTestCase", "AsyncStorageRecordedTestCase"]
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import functools

from .. import StorageTestCase
from .. import StorageTestCase, StorageRecordedTestCase
from ...fake_credentials_async import AsyncFakeCredential

from azure_devtools.scenario_tests.patches import mock_in_unit_test
Expand Down Expand Up @@ -67,3 +67,34 @@ def generate_oauth_token(self):

def generate_fake_token(self):
return AsyncFakeCredential()


class AsyncStorageRecordedTestCase(StorageRecordedTestCase):

@staticmethod
def await_prepared_test(test_fn):
"""Synchronous wrapper for async test methods. Used to avoid making changes
upstream to AbstractPreparer (which doesn't await the functions it wraps)
"""

@functools.wraps(test_fn)
def run(test_class_instance, *args, **kwargs):
trim_kwargs_from_test_function(test_fn, kwargs)
loop = asyncio.get_event_loop()
return loop.run_until_complete(test_fn(test_class_instance, **kwargs))

return run

def generate_oauth_token(self):
if self.is_live:
from azure.identity.aio import ClientSecretCredential

return ClientSecretCredential(
self.get_settings_value("TENANT_ID"),
self.get_settings_value("CLIENT_ID"),
self.get_settings_value("CLIENT_SECRET"),
)
return self.generate_fake_token()

def generate_fake_token(self):
return AsyncFakeCredential()
176 changes: 170 additions & 6 deletions tools/azure-sdk-tools/devtools_testutils/storage/testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import division

from datetime import datetime, timedelta
from io import StringIO
import logging
import math
import os
Expand All @@ -15,17 +16,14 @@
import time
import zlib

from devtools_testutils import AzureTestCase
import pytest

from devtools_testutils import AzureTestCase, AzureRecordedTestCase

from .processors import XMSRequestIDBody
from . import ApiVersionAssertPolicy, service_version_map
from .. import FakeTokenCredential

try:
from cStringIO import StringIO # Python 2
except ImportError:
from io import StringIO

try:
from azure.storage.blob import generate_account_sas, AccountSasPermissions, ResourceTypes
except:
Expand All @@ -39,6 +37,19 @@
ENABLE_LOGGING = True


def generate_sas_token():
fake_key = "a" * 30 + "b" * 30

return "?" + generate_account_sas(
account_name="test", # name of the storage account
account_key=fake_key, # key for the storage account
resource_types=ResourceTypes(object=True),
permission=AccountSasPermissions(read=True, list=True),
start=datetime.now() - timedelta(hours=24),
expiry=datetime.now() + timedelta(days=8),
)


class StorageTestCase(AzureTestCase):
def __init__(self, *args, **kwargs):
super(StorageTestCase, self).__init__(*args, **kwargs)
Expand Down Expand Up @@ -209,6 +220,159 @@ def create_storage_client_from_conn_str(self, client, *args, **kwargs):
return client.from_connection_string(*args, **kwargs)


class StorageRecordedTestCase(AzureRecordedTestCase):

def setup_class(cls):
cls.logger = logging.getLogger("azure.storage")
cls.sas_token = generate_sas_token()

def setup_method(self, _):
self.configure_logging()

def connection_string(self, account_name, key):
return (
"DefaultEndpointsProtocol=https;AcCounTName="
+ account_name
+ ";AccOuntKey="
+ str(key)
+ ";EndpoIntSuffix=core.windows.net"
)

def account_url(self, storage_account, storage_type):
"""Return an url of storage account.
:param str storage_account: Storage account name
:param str storage_type: The Storage type part of the URL. Should be "blob", or "queue", etc.
"""
protocol = os.environ.get("PROTOCOL", "https")
suffix = os.environ.get("ACCOUNT_URL_SUFFIX", "core.windows.net")
return f"{protocol}://{storage_account}.{storage_type}.{suffix}"

def configure_logging(self):
enable_logging = ENABLE_LOGGING

self.enable_logging() if enable_logging else self.disable_logging()

def enable_logging(self):
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(LOGGING_FORMAT))
self.logger.handlers = [handler]
self.logger.setLevel(logging.DEBUG)
self.logger.propagate = True
self.logger.disabled = False

def disable_logging(self):
self.logger.propagate = False
self.logger.disabled = True
self.logger.handlers = []

def get_random_bytes(self, size):
# recordings don't like random stuff. making this more
# deterministic.
return b"a" * size

def get_random_text_data(self, size):
"""Returns random unicode text data exceeding the size threshold for
chunking blob upload."""
checksum = zlib.adler32(self.qualified_test_name.encode()) & 0xFFFFFFFF
rand = random.Random(checksum)
text = u""
words = [u"hello", u"world", u"python", u"啊齄丂狛狜"]
while len(text) < size:
index = int(rand.random() * (len(words) - 1))
text = text + u" " + words[index]

return text

@staticmethod
def _set_test_proxy(service, settings):
if settings.USE_PROXY:
service.set_proxy(
settings.PROXY_HOST,
settings.PROXY_PORT,
settings.PROXY_USER,
settings.PROXY_PASSWORD,
)

def assertNamedItemInContainer(self, container, item_name, msg=None):
def _is_string(obj):
return isinstance(obj, str)

for item in container:
if _is_string(item):
if item == item_name:
return
elif isinstance(item, dict):
if item_name == item["name"]:
return
elif item.name == item_name:
return
elif hasattr(item, "snapshot") and item.snapshot == item_name:
return

error_message = f"{repr(item_name)} not found in {[str(c) for c in container]}"
pytest.fail(error_message)

def assertNamedItemNotInContainer(self, container, item_name, msg=None):
for item in container:
if item.name == item_name:
error_message = f"{repr(item_name)} unexpectedly found in {repr(container)}"
pytest.fail(error_message)

def assert_upload_progress(self, size, max_chunk_size, progress, unknown_size=False):
"""Validates that the progress chunks align with our chunking procedure."""
total = None if unknown_size else size
small_chunk_size = size % max_chunk_size
assert len(progress) == math.ceil(size / max_chunk_size)
for i in progress:
assert i[0] % max_chunk_size == 0 or i[0] % max_chunk_size == small_chunk_size
assert i[1] == total

def assert_download_progress(self, size, max_chunk_size, max_get_size, progress):
"""Validates that the progress chunks align with our chunking procedure."""
if size <= max_get_size:
assert len(progress) == 1
assert progress[0][0], size
assert progress[0][1], size
else:
small_chunk_size = (size - max_get_size) % max_chunk_size
assert len(progress) == 1 + math.ceil((size - max_get_size) / max_chunk_size)

assert progress[0][0], max_get_size
assert progress[0][1], size
for i in progress[1:]:
assert i[0] % max_chunk_size == 0 or i[0] % max_chunk_size == small_chunk_size
assert i[1] == size

def generate_oauth_token(self):
if self.is_live:
from azure.identity import ClientSecretCredential

return ClientSecretCredential(
self.get_settings_value("TENANT_ID"),
self.get_settings_value("CLIENT_ID"),
self.get_settings_value("CLIENT_SECRET"),
)
return self.generate_fake_token()

def generate_fake_token(self):
return FakeTokenCredential()

def _get_service_version(self, **kwargs):
env_version = service_version_map.get(os.environ.get("AZURE_LIVE_TEST_SERVICE_VERSION", "LATEST"))
return kwargs.pop("service_version", env_version)

def create_storage_client(self, client, *args, **kwargs):
kwargs["api_version"] = self._get_service_version(**kwargs)
kwargs["_additional_pipeline_policies"] = [ApiVersionAssertPolicy(kwargs["api_version"])]
return client(*args, **kwargs)

def create_storage_client_from_conn_str(self, client, *args, **kwargs):
kwargs["api_version"] = self._get_service_version(**kwargs)
kwargs["_additional_pipeline_policies"] = [ApiVersionAssertPolicy(kwargs["api_version"])]
return client.from_connection_string(*args, **kwargs)


class LogCaptured(object):
def __init__(self, test_case=None):
# accept the test case so that we may reset logging after capturing logs
Expand Down

0 comments on commit d653189

Please sign in to comment.