Skip to content

Commit

Permalink
Refactor module utils and add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
abikouo committed Nov 14, 2022
1 parent f1e16a2 commit d894aaa
Show file tree
Hide file tree
Showing 3 changed files with 632 additions and 110 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
minor_changes:
- Refactor module_utils/cloudfront_facts.py and add unit tests (https://github.com/ansible-collections/amazon.aws/pull/1265).

230 changes: 120 additions & 110 deletions plugins/module_utils/cloudfront_facts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,109 +26,137 @@
Common cloudfront facts shared between modules
"""

from functools import partial
try:
import botocore
except ImportError:
pass

from .ec2 import AWSRetry
from .ec2 import boto3_tag_list_to_ansible_dict
from .ec2 import snake_dict_to_camel_dict


class CloudFrontFactsServiceManager:
"""Handles CloudFront Facts Services"""
class CloudFrontFactsServiceManagerFailure(Exception):
pass

def __init__(self, module):
self.module = module
self.client = module.client('cloudfront', retry_decorator=AWSRetry.jittered_backoff())

def get_distribution(self, distribution_id):
try:
return self.client.get_distribution(Id=distribution_id, aws_retry=True)
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error describing distribution")
def cloudfront_facts_keyed_list_helper(list_to_key):
keyed_list = dict()
for item in list_to_key:
distribution_id = item['Id']
if 'Items' in item['Aliases']:
aliases = item['Aliases']['Items']
for alias in aliases:
keyed_list.update({alias: item})
keyed_list.update({distribution_id: item})
return keyed_list

def get_distribution_config(self, distribution_id):
try:
return self.client.get_distribution_config(Id=distribution_id, aws_retry=True)
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error describing distribution configuration")

def get_origin_access_identity(self, origin_access_identity_id):
try:
return self.client.get_cloud_front_origin_access_identity(Id=origin_access_identity_id, aws_retry=True)
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error describing origin access identity")
def _cloudfront_paginate_build_full_result(client, client_method, **kwargs):
print("Inside this")
paginator = client.get_paginator(client_method)
return paginator.paginate(**kwargs).build_full_result()

def get_origin_access_identity_config(self, origin_access_identity_id):
try:
return self.client.get_cloud_front_origin_access_identity_config(Id=origin_access_identity_id, aws_retry=True)
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error describing origin access identity configuration")

def get_invalidation(self, distribution_id, invalidation_id):
try:
return self.client.get_invalidation(DistributionId=distribution_id, Id=invalidation_id, aws_retry=True)
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error describing invalidation")
class CloudFrontFactsServiceManager:
"""Handles CloudFront Facts Services"""

def get_streaming_distribution(self, distribution_id):
try:
return self.client.get_streaming_distribution(Id=distribution_id, aws_retry=True)
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error describing streaming distribution")
CLOUDFRONT_CLIENT_API_MAPPING = {
"get_distribution": {
"error": "Error describing distribution",
},
"get_distribution_config": {
"error": "Error describing distribution configuration",
},
"get_origin_access_identity": {
"error": "Error describing origin access identity",
"client_api": "get_cloud_front_origin_access_identity"
},
"get_origin_access_identity_config": {
"error": "Error describing origin access identity configuration",
"client_api": "get_cloud_front_origin_access_identity_config"
},
"get_streaming_distribution": {
"error": "Error describing streaming distribution",
},
"get_streaming_distribution_config": {
"error": "Error describing streaming distribution",
},
"get_invalidation": {
"error": "Error describing invalidation"
},
"list_distributions_by_web_acl_id": {
"error": "Error listing distributions by web acl id",
"post_process": lambda x: cloudfront_facts_keyed_list_helper(x.get('DistributionList', {}).get('Items', []))
}
}

CLOUDFRONT_CLIENT_PAGINATE_API_MAPPING = {
"list_origin_access_identities": {
"error": "Error listing cloud front origin access identities",
"client_api": "list_cloud_front_origin_access_identities",
"key": "CloudFrontOriginAccessIdentityList"
},
"list_distributions": {
"error": "Error listing distributions",
"key": "DistributionList",
"keyed": True,
},
"list_invalidations": {
"error": "Error listing invalidations",
"key": "InvalidationList"
},
"list_streaming_distributions": {
"error": "Error listing streaming distributions",
"key": "StreamingDistributionList",
"keyed": True,
}
}

def get_streaming_distribution_config(self, distribution_id):
try:
return self.client.get_streaming_distribution_config(Id=distribution_id, aws_retry=True)
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error describing streaming distribution")
def __init__(self, module):
self.module = module
self.client = module.client('cloudfront', retry_decorator=AWSRetry.jittered_backoff())

def list_origin_access_identities(self):
def describe_cloudfront_property(self, client_method, error, post_process, **kwargs):
try:
paginator = self.client.get_paginator('list_cloud_front_origin_access_identities')
result = paginator.paginate().build_full_result().get('CloudFrontOriginAccessIdentityList', {})
return result.get('Items', [])
method = getattr(self.client, client_method)
api_kwargs = snake_dict_to_camel_dict(kwargs, capitalize_first=True)
result = method(aws_retry=True, **api_kwargs)
if post_process:
result = post_process(result)
return result
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error listing cloud front origin access identities")
self.module.fail_json_aws(e, msg=error)

def list_distributions(self, keyed=True):
def paginate_list_cloudfront_property(self, client_method, key, keyed, error, **kwargs):
try:
paginator = self.client.get_paginator('list_distributions')
result = paginator.paginate().build_full_result().get('DistributionList', {})
distribution_list = result.get('Items', [])
if not keyed:
return distribution_list
return self.keyed_list_helper(distribution_list)
keyed = kwargs.pop("keyed", keyed)
api_kwargs = snake_dict_to_camel_dict(kwargs, capitalize_first=True)
result = _cloudfront_paginate_build_full_result(self.client, client_method, **api_kwargs)
items = result.get(key, {}).get('Items', [])
if keyed:
items = cloudfront_facts_keyed_list_helper(items)
return items
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error listing distributions")
self.module.fail_json_aws(e, msg=error)

def list_distributions_by_web_acl_id(self, web_acl_id):
try:
result = self.client.list_distributions_by_web_acl_id(WebAclId=web_acl_id, aws_retry=True)
distribution_list = result.get('DistributionList', {}).get('Items', [])
return self.keyed_list_helper(distribution_list)
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error listing distributions by web acl id")
def __getattr__(self, name):

def list_invalidations(self, distribution_id):
try:
paginator = self.client.get_paginator('list_invalidations')
result = paginator.paginate(DistributionId=distribution_id).build_full_result()
return result.get('InvalidationList', {}).get('Items', [])
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error listing invalidations")
if name in self.CLOUDFRONT_CLIENT_API_MAPPING:
client_method = self.CLOUDFRONT_CLIENT_API_MAPPING[name].get('client_api', name)
error = self.CLOUDFRONT_CLIENT_API_MAPPING[name].get('error', '')
post_process = self.CLOUDFRONT_CLIENT_API_MAPPING[name].get('post_process')
return partial(self.describe_cloudfront_property, client_method, error, post_process)

def list_streaming_distributions(self, keyed=True):
try:
paginator = self.client.get_paginator('list_streaming_distributions')
result = paginator.paginate().build_full_result()
streaming_distribution_list = result.get('StreamingDistributionList', {}).get('Items', [])
if not keyed:
return streaming_distribution_list
return self.keyed_list_helper(streaming_distribution_list)
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error listing streaming distributions")
elif name in self.CLOUDFRONT_CLIENT_PAGINATE_API_MAPPING:
client_method = self.CLOUDFRONT_CLIENT_PAGINATE_API_MAPPING[name].get('client_api', name)
error = self.CLOUDFRONT_CLIENT_PAGINATE_API_MAPPING[name].get('error', '')
key = self.CLOUDFRONT_CLIENT_PAGINATE_API_MAPPING[name].get('key')
keyed = self.CLOUDFRONT_CLIENT_PAGINATE_API_MAPPING[name].get('keyed', False)
return partial(self.paginate_list_cloudfront_property, client_method, key, keyed, error)

raise CloudFrontFactsServiceManagerFailure("Method {0} is not currently supported".format(name))

def summary(self):
summary_dict = {}
Expand All @@ -139,35 +167,35 @@ def summary(self):

def summary_get_origin_access_identity_list(self):
try:
origin_access_identity_list = {'origin_access_identities': []}
origin_access_identities = self.list_origin_access_identities()
for origin_access_identity in origin_access_identities:
origin_access_identities = []
for origin_access_identity in self.list_origin_access_identities():
oai_id = origin_access_identity['Id']
oai_full_response = self.get_origin_access_identity(oai_id)
oai_summary = {'Id': oai_id, 'ETag': oai_full_response['ETag']}
origin_access_identity_list['origin_access_identities'].append(oai_summary)
return origin_access_identity_list
origin_access_identities.append(oai_summary)
return {'origin_access_identities': origin_access_identities}
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error generating summary of origin access identities")

def list_resource_tags(self, resource_arn):
return self.client.list_tags_for_resource(Resource=resource_arn, aws_retry=True)

def summary_get_distribution_list(self, streaming=False):
try:
list_name = 'streaming_distributions' if streaming else 'distributions'
key_list = ['Id', 'ARN', 'Status', 'LastModifiedTime', 'DomainName', 'Comment', 'PriceClass', 'Enabled']
distribution_list = {list_name: []}
distributions = self.list_streaming_distributions(False) if streaming else self.list_distributions(False)
for dist in distributions:
temp_distribution = {}
for key_name in key_list:
temp_distribution[key_name] = dist[key_name]
temp_distribution = {k: dist[k] for k in key_list}
temp_distribution['Aliases'] = list(dist['Aliases'].get('Items', []))
temp_distribution['ETag'] = self.get_etag_from_distribution_id(dist['Id'], streaming)
if not streaming:
temp_distribution['WebACLId'] = dist['WebACLId']
invalidation_ids = self.get_list_of_invalidation_ids_from_distribution_id(dist['Id'])
if invalidation_ids:
temp_distribution['Invalidations'] = invalidation_ids
resource_tags = self.client.list_tags_for_resource(Resource=dist['ARN'], aws_retry=True)
resource_tags = self.list_resource_tags(dist['ARN'])
temp_distribution['Tags'] = boto3_tag_list_to_ansible_dict(resource_tags['Tags'].get('Items', []))
distribution_list[list_name].append(temp_distribution)
return distribution_list
Expand All @@ -177,18 +205,14 @@ def summary_get_distribution_list(self, streaming=False):
def get_etag_from_distribution_id(self, distribution_id, streaming):
distribution = {}
if not streaming:
distribution = self.get_distribution(distribution_id)
distribution = self.get_distribution(id=distribution_id)
else:
distribution = self.get_streaming_distribution(distribution_id)
distribution = self.get_streaming_distribution(id=distribution_id)
return distribution['ETag']

def get_list_of_invalidation_ids_from_distribution_id(self, distribution_id):
try:
invalidation_ids = []
invalidations = self.list_invalidations(distribution_id)
for invalidation in invalidations:
invalidation_ids.append(invalidation['Id'])
return invalidation_ids
return list(map(lambda x: x['Id'], self.list_invalidations(distribution_id=distribution_id)))
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error getting list of invalidation ids")

Expand All @@ -198,29 +222,15 @@ def get_distribution_id_from_domain_name(self, domain_name):
distributions = self.list_distributions(False)
distributions += self.list_streaming_distributions(False)
for dist in distributions:
if 'Items' in dist['Aliases']:
for alias in dist['Aliases']['Items']:
if str(alias).lower() == domain_name.lower():
distribution_id = dist['Id']
break
if any(str(alias).lower() == domain_name.lower() for alias in dist['Aliases'].get('Items', [])):
distribution_id = dist['Id']
return distribution_id
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error getting distribution id from domain name")

def get_aliases_from_distribution_id(self, distribution_id):
try:
distribution = self.get_distribution(distribution_id)
distribution = self.get_distribution(id=distribution_id)
return distribution['DistributionConfig']['Aliases'].get('Items', [])
except botocore.exceptions.ClientError as e:
self.module.fail_json_aws(e, msg="Error getting list of aliases from distribution_id")

def keyed_list_helper(self, list_to_key):
keyed_list = dict()
for item in list_to_key:
distribution_id = item['Id']
if 'Items' in item['Aliases']:
aliases = item['Aliases']['Items']
for alias in aliases:
keyed_list.update({alias: item})
keyed_list.update({distribution_id: item})
return keyed_list
Loading

0 comments on commit d894aaa

Please sign in to comment.