Skip to content

Commit

Permalink
inventory/aws_ec2: extend unit-test coverage (ansible-collections#1093)
Browse files Browse the repository at this point in the history
inventory/aws_ec2: extend unit-test coverage

Depends-On: ansible-collections#1090
Break up the _populate() method into smaller function. Use @classmethod when possible and
reduce the use of get_option() to avoid side-effects.

Reviewed-by: Mark Chappell <None>
Reviewed-by: Gonéri Le Bouder <[email protected]>
Reviewed-by: Alina Buzachis <None>
  • Loading branch information
goneri authored Oct 3, 2022
1 parent 1cac4da commit 758d594
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 72 deletions.
3 changes: 3 additions & 0 deletions changelogs/fragments/inventory-aws_ec2_unit-tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
trivial:
- inventory/aws_ec2 - Expand unit tests.
124 changes: 79 additions & 45 deletions plugins/inventory/aws_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,8 @@ def _get_event_set_and_persistence(self, connection, instance_id, spot_instance)
raise AnsibleError("Failed to describe spot instance requests: %s" % to_native(e))
return host_vars

def _get_tag_hostname(self, preference, instance):
@classmethod
def _get_tag_hostname(cls, preference, instance):
tag_hostnames = preference.split('tag:', 1)[1]
if ',' in tag_hostnames:
tag_hostnames = tag_hostnames.split(',')
Expand Down Expand Up @@ -664,10 +665,7 @@ def _get_preferred_hostname(self, instance, hostnames):
if hostname:
break
if hostname:
if ':' in to_text(hostname):
return self._sanitize_group_name((to_text(hostname)))
else:
return to_text(hostname)
return self._sanitize_hostname(hostname)

def get_all_hostnames(self, instance, hostnames):
'''
Expand Down Expand Up @@ -732,23 +730,47 @@ def _query(self, regions, include_filters, exclude_filters, strict_permissions):

return {'aws_ec2': instances}

def _populate(self, groups, hostnames, allow_duplicated_hosts=False):
def _populate(self, groups, hostnames, allow_duplicated_hosts=False,
hostvars_prefix=None, hostvars_suffix=None,
use_contrib_script_compatible_ec2_tag_keys=False):
for group in groups:
group = self.inventory.add_group(group)
self._add_hosts(
hosts=groups[group],
group=group,
hostnames=hostnames,
allow_duplicated_hosts=allow_duplicated_hosts)
allow_duplicated_hosts=allow_duplicated_hosts,
hostvars_prefix=hostvars_prefix,
hostvars_suffix=hostvars_suffix,
use_contrib_script_compatible_ec2_tag_keys=use_contrib_script_compatible_ec2_tag_keys)
self.inventory.add_child('all', group)

def _add_hosts(self, hosts, group, hostnames, allow_duplicated_hosts=False):
'''
:param hosts: a list of hosts to be added to a group
:param group: the name of the group to which the hosts belong
:param hostnames: a list of hostname destination variables in order of preference
:param bool allow_duplicated_hosts: allow multiple copy for the same host with a different name
'''
@classmethod
def prepare_host_vars(cls, original_host_vars, hostvars_prefix=None, hostvars_suffix=None,
use_contrib_script_compatible_ec2_tag_keys=False):
host_vars = camel_dict_to_snake_dict(original_host_vars, ignore_list=['Tags'])
host_vars['tags'] = boto3_tag_list_to_ansible_dict(original_host_vars.get('Tags', []))

# Allow easier grouping by region
host_vars['placement']['region'] = host_vars['placement']['availability_zone'][:-1]

if use_contrib_script_compatible_ec2_tag_keys:
for k, v in host_vars['tags'].items():
host_vars["ec2_tag_%s" % k] = v

if hostvars_prefix or hostvars_suffix:
for hostvar, hostval in host_vars.copy().items():
del(host_vars[hostvar])
if hostvars_prefix:
hostvar = hostvars_prefix + hostvar
if hostvars_suffix:
hostvar = hostvar + hostvars_suffix
host_vars[hostvar] = hostval

return host_vars

def iter_entry(self, hosts, hostnames, allow_duplicated_hosts=False, hostvars_prefix=None,
hostvars_suffix=None, use_contrib_script_compatible_ec2_tag_keys=False):
for host in hosts:
if allow_duplicated_hosts:
hostname_list = self.get_all_hostnames(host, hostnames)
Expand All @@ -757,46 +779,48 @@ def _add_hosts(self, hosts, group, hostnames, allow_duplicated_hosts=False):
else:
continue

host = camel_dict_to_snake_dict(host, ignore_list=['Tags'])
host['tags'] = boto3_tag_list_to_ansible_dict(host.get('tags', []))
host_vars = self.prepare_host_vars(
host,
hostvars_prefix,
hostvars_suffix,
use_contrib_script_compatible_ec2_tag_keys)
for name in hostname_list:
yield to_text(name), host_vars

if self.get_option('use_contrib_script_compatible_ec2_tag_keys'):
for k, v in host['tags'].items():
host["ec2_tag_%s" % k] = v

# Allow easier grouping by region
host['placement']['region'] = host['placement']['availability_zone'][:-1]
def _add_hosts(self, hosts, group, hostnames, allow_duplicated_hosts=False,
hostvars_prefix=None, hostvars_suffix=None, use_contrib_script_compatible_ec2_tag_keys=False):
'''
:param hosts: a list of hosts to be added to a group
:param group: the name of the group to which the hosts belong
:param hostnames: a list of hostname destination variables in order of preference
:param bool allow_duplicated_hosts: if true, accept same host with different names
:param str hostvars_prefix: starts the hostvars variable name with this prefix
:param str hostvars_suffix: ends the hostvars variable name with this suffix
:param bool use_contrib_script_compatible_ec2_tag_keys: transform the host name with the legacy naming system
'''

if not hostname_list:
continue
for hostname in hostname_list:
self.inventory.add_host(to_text(hostname), group=group)
hostvars_prefix = self.get_option("hostvars_prefix")
hostvars_suffix = self.get_option("hostvars_suffix")
new_vars = dict()
for hostvar, hostval in host.items():
if hostvars_prefix:
hostvar = hostvars_prefix + hostvar
if hostvars_suffix:
hostvar = hostvar + hostvars_suffix
new_vars[hostvar] = hostval
for hostname in hostname_list:
self.inventory.set_variable(to_text(hostname), hostvar, hostval)
host.update(new_vars)
for name, host_vars in self.iter_entry(
hosts, hostnames,
allow_duplicated_hosts=allow_duplicated_hosts,
hostvars_prefix=hostvars_prefix,
hostvars_suffix=hostvars_suffix,
use_contrib_script_compatible_ec2_tag_keys=use_contrib_script_compatible_ec2_tag_keys):
self.inventory.add_host(name, group=group)
for k, v in host_vars.items():
self.inventory.set_variable(name, k, v)

# Use constructed if applicable

strict = self.get_option('strict')

# Composed variables
for hostname in hostname_list:
self._set_composite_vars(self.get_option('compose'), host, to_text(hostname), strict=strict)
self._set_composite_vars(self.get_option('compose'), host_vars, name, strict=strict)

# Complex groups based on jinja2 conditionals, hosts that meet the conditional are added to group
self._add_host_to_composed_groups(self.get_option('groups'), host, to_text(hostname), strict=strict)
# Complex groups based on jinja2 conditionals, hosts that meet the conditional are added to group
self._add_host_to_composed_groups(self.get_option('groups'), host_vars, name, strict=strict)

# Create groups based on variable values and add the corresponding hosts to it
self._add_host_to_keyed_groups(self.get_option('keyed_groups'), host, to_text(hostname), strict=strict)
# Create groups based on variable values and add the corresponding hosts to it
self._add_host_to_keyed_groups(self.get_option('keyed_groups'), host_vars, name, strict=strict)

def _set_credentials(self, loader):
'''
Expand Down Expand Up @@ -875,6 +899,10 @@ def parse(self, inventory, loader, path, cache=True):
strict_permissions = self.get_option('strict_permissions')
allow_duplicated_hosts = self.get_option('allow_duplicated_hosts')

hostvars_prefix = self.get_option("hostvars_prefix")
hostvars_suffix = self.get_option("hostvars_suffix")
use_contrib_script_compatible_ec2_tag_keys = self.get_option('use_contrib_script_compatible_ec2_tag_keys')

cache_key = self.get_cache_key(path)
# false when refresh_cache or --flush-cache is used
if cache:
Expand All @@ -893,7 +921,13 @@ def parse(self, inventory, loader, path, cache=True):
if not cache or cache_needs_update:
results = self._query(regions, include_filters, exclude_filters, strict_permissions)

self._populate(results, hostnames, allow_duplicated_hosts=allow_duplicated_hosts)
self._populate(
results,
hostnames,
allow_duplicated_hosts=allow_duplicated_hosts,
hostvars_prefix=hostvars_prefix,
hostvars_suffix=hostvars_suffix,
use_contrib_script_compatible_ec2_tag_keys=use_contrib_script_compatible_ec2_tag_keys)

# If the cache has expired/doesn't exist or if refresh_inventory/flush cache is used
# when the user is using caching, update the cached inventory
Expand Down
Loading

0 comments on commit 758d594

Please sign in to comment.