Skip to content

Commit

Permalink
ansible: add "groups" key to get_variables()
Browse files Browse the repository at this point in the history
This key was returned by ansible python API in testinfra 2.X and users
rely on it.

Closes #443
  • Loading branch information
philpep committed May 9, 2019
1 parent e01ccc6 commit febf654
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
7 changes: 7 additions & 0 deletions test/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,25 @@ def test_ansible_get_variables():

def get_vars(host):
return AnsibleRunner(f.name).get_variables(host)
groups = {
'all': ['centos', 'debian'],
'g': ['debian'],
'ungrouped': ['centos'],
}
assert get_vars("debian") == {
'a': 'b',
'c': 'd',
'x': 'z',
'inventory_hostname': 'debian',
'group_names': ['g'],
'groups': groups,
}
assert get_vars("centos") == {
'a': 'a',
'e': 'f',
'inventory_hostname': 'centos',
'group_names': ['ungrouped'],
'groups': groups,
}


Expand Down
5 changes: 5 additions & 0 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,11 @@ def test_ansible_module(host):
assert variables["mygroupvar"] == "qux"
assert variables["inventory_hostname"] == "debian_stretch"
assert variables["group_names"] == ["testgroup"]
assert variables["groups"] == {
"all": ["debian_stretch"],
"testgroup": ["debian_stretch"],
"ungrouped": [],
}

with pytest.raises(host.ansible.AnsibleException) as excinfo:
host.ansible("command", "zzz")
Expand Down
15 changes: 10 additions & 5 deletions testinfra/utils/ansible_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,21 @@ def ansible_config(self):

def get_variables(self, host):
inventory = self.inventory
# inventory_hostname, group_names and groups are for backward
# compatibility with testinfra 2.X
hostvars = inventory['_meta'].get(
'hostvars', {}).get(host, {})
hostvars.setdefault('inventory_hostname', host)
groups = []
group_names = []
groups = {}
for group in sorted(inventory):
if group in ('_meta', 'all'):
if group == "_meta":
continue
if host in inventory[group].get('hosts', []):
groups.append(group)
hostvars.setdefault('group_names', groups)
groups[group] = sorted(list(itergroup(inventory, group)))
if group != "all" and host in inventory[group].get('hosts', []):
group_names.append(group)
hostvars.setdefault('group_names', group_names)
hostvars.setdefault('groups', groups)
return hostvars

def get_host(self, host, **kwargs):
Expand Down

0 comments on commit febf654

Please sign in to comment.