Skip to content

Commit

Permalink
Implement a new Ansible Runner
Browse files Browse the repository at this point in the history
Following specification on #431

* Use ansible-inventory cli to parse inventory, group and host variables

* Ansible backend re-use existing backend (local, ssh) to run commands.
  This is a breaking change because we do not support all connection
  backends from ansible.

* The Ansible module run with the ansible cli.

* Add more tests

* Fix the skipped "encoding" test which pass now.

This fix maintainability issue with ansible and the license issue
(ansible is GPL).

Closes #431
  • Loading branch information
philpep committed Apr 29, 2019
1 parent 436a783 commit 4c27a1e
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 148 deletions.
30 changes: 27 additions & 3 deletions test/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ def test_command(host):

@pytest.mark.testinfra_hosts(*HOSTS)
def test_encoding(host):
if host.backend.get_connection_type() == "ansible":
pytest.skip("ansible handle encoding himself")

# stretch image is fr_FR@ISO-8859-15
cmd = host.run("ls -l %s", "/é")
if host.backend.get_connection_type() == "docker":
Expand Down Expand Up @@ -124,6 +121,33 @@ def get_vars(host):
}


def test_ansible_get_backend():
with tempfile.NamedTemporaryFile() as f:
f.write((
b'localhost ansible_connection=local ansible_become=yes\n'
b'debian ansible_user=u ansible_become=yes\n'
b'centos ansible_connection=ssh ansible_host=127.0.0.1 '
b'ansible_port=2222\n'
))
f.flush()

def get_backend(host):
return AnsibleRunner(f.name).get_backend(host).backend
localhost = get_backend('localhost')
assert localhost.NAME == 'local'
assert localhost.sudo
debian = get_backend('debian')
assert debian.NAME == 'paramiko'
assert debian.sudo
assert debian.host.name == 'debian'
assert debian.host.user == 'u'
centos = get_backend('centos')
assert centos.NAME == 'paramiko'
assert not centos.sudo
assert centos.host.name == '127.0.0.1'
assert centos.host.port == '2222'


def test_backend_importables():
# just check that all declared backend are importable and NAME is set
# correctly
Expand Down
15 changes: 2 additions & 13 deletions testinfra/backend/ansible.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from testinfra.backend import base
from testinfra.utils.ansible_runner import AnsibleRunner
from testinfra.utils.ansible_runner import to_bytes

logger = logging.getLogger("testinfra")

Expand All @@ -39,20 +38,10 @@ def ansible_runner(self):

def run(self, command, *args, **kwargs):
command = self.get_command(command, *args)
out = self.run_ansible("shell", module_args=command)
return self.result(
out['rc'],
command,
stdout_bytes=None,
stderr_bytes=None,
stdout=out["stdout"], stderr=out["stderr"],
)

def encode(self, data):
return to_bytes(data)
return self.ansible_runner.run(self.host, command)

def run_ansible(self, module_name, module_args=None, **kwargs):
result = self.ansible_runner.run(
result = self.ansible_runner.run_module(
self.host, module_name, module_args,
**kwargs)
logger.info(
Expand Down
295 changes: 163 additions & 132 deletions testinfra/utils/ansible_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,56 +11,183 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=import-error,no-name-in-module,no-member
# pylint: disable=unexpected-keyword-arg,no-value-for-parameter
# pylint: disable=arguments-differ

from __future__ import unicode_literals
from __future__ import absolute_import

import pprint


try:
import ansible
except ImportError:
raise RuntimeError(
"You must install ansible package to use the ansible backend")
import fnmatch
import json
import os
import tempfile

import ansible.cli.playbook
import ansible.constants
import ansible.executor.task_queue_manager
import ansible.inventory
import ansible.parsing.dataloader
import ansible.playbook.play
import ansible.plugins.callback
import ansible.utils.vars
import ansible.vars
from six.moves import configparser

try:
from ansible.module_utils._text import to_bytes
except ImportError:
from ansible.utils.unicode import to_bytes
import testinfra
from testinfra.utils import cached_property


__all__ = ['AnsibleRunner', 'to_bytes']
__all__ = ['AnsibleRunner']


class AnsibleRunnerBase(object):
EMPTY_INVENTORY = {
"_meta": {
"hostvars": {}
},
"all": {
"children": [
"ungrouped"
]
},
"ungrouped": {}
}
local = testinfra.get_host('local://')


def get_ansible_config():
fname = os.environ.get('ANSIBLE_CONFIG')
if not fname:
for possible in (
os.path.join(os.path.expanduser('~'), '.ansible.cfg'),
os.path.join('/', 'etc', 'ansible', 'ansible.cfg'),
):
if os.path.exists(possible):
fname = possible
break
config = configparser.ConfigParser()
if not fname:
return config
config.read(fname)
return config


def get_ansible_inventory(inventory_file):
cmd = 'ansible-inventory --list'
args = []
if inventory_file:
cmd += ' -i %s'
args += [inventory_file]
return json.loads(local.check_output(cmd, *args))


def get_backend(config, inventory, host):
if inventory == EMPTY_INVENTORY:
return testinfra.get_host('local://')
else:
hostvars = inventory['_meta'].get('hostvars', {}).get(host, {})
connection = hostvars.get('ansible_connection', 'ssh')
if connection not in ('ssh', 'local', 'docker'):
raise NotImplementedError(
'unhandled ansible_connection {}'.format(connection))
if connection == 'ssh':
connection = 'paramiko'
testinfra_host = hostvars.get('ansible_host', host)
user = hostvars.get('ansible_user')
port = hostvars.get('ansible_port')
kwargs = {}
if hostvars.get('ansible_become', False):
kwargs['sudo'] = True
if 'ansible_ssh_private_key_file' in hostvars:
kwargs['ssh_identity_file'] = hostvars[
'ansible_ssh_private_key_file']
try:
host_key_checking = config['defaults']['host_key_checking']
except KeyError:
pass
else:
if host_key_checking.lower()[:1] in ('n', 'f', '0'):
kwargs['strict_host_key_checking'] = False
spec = '{}://'.format(connection)
if user:
spec += '{}@'.format(user)
spec += testinfra_host
if port:
spec += ':{}'.format(port)
return testinfra.get_host(spec, **kwargs)


class AnsibleRunner(object):
_runners = {}

def __init__(self, host_list=None):
self.host_list = host_list
super(AnsibleRunnerBase, self).__init__()
def __init__(self, inventory_file=None):
self.inventory_file = inventory_file
self._backend_cache = {}
super(AnsibleRunner, self).__init__()

def get_hosts(self, pattern=None):
raise NotImplementedError
inventory = self.inventory
result = set()
if inventory == EMPTY_INVENTORY:
# use localhost as fallback
result.add('localhost')
else:
for group in inventory:
groupmatch = fnmatch.fnmatch(group, pattern)
for host in inventory[group].get('hosts', []):
if (groupmatch or pattern == 'all'
or fnmatch.fnmatch(host, pattern)):
result.add(host)
return sorted(result)

@cached_property
def inventory(self):
return get_ansible_inventory(self.inventory_file)

@cached_property
def ansible_config(self):
return get_ansible_config()

def get_variables(self, host):
raise NotImplementedError

def run(self, host, module_name, module_args, **kwargs):
raise NotImplementedError
inventory = self.inventory
hostvars = inventory['_meta'].get(
'hostvars', {}).get(host, {})
hostvars.setdefault('inventory_hostname', host)
groups = []
for group in sorted(inventory):
if group in ('_meta', 'all'):
continue
if host in inventory[group].get('hosts', []):
groups.append(group)
hostvars.setdefault('group_names', groups)
return hostvars

def get_backend(self, host):
try:
return self._backend_cache[host]
except KeyError:
backend = self._backend_cache[host] = get_backend(
self.ansible_config, self.inventory, host)
return backend

def run(self, host, command):
return self.get_backend(host).run(command)

def run_module(self, host, module_name, module_args, become=False,
check=True, **kwargs):
cmd, args = 'ansible --tree %s', [None]
if self.inventory_file:
cmd += ' -i %s'
args += [self.inventory_file]
cmd += ' -m %s'
args += [module_name]
if module_args:
cmd += ' --args %s'
args += [module_args]
if become:
cmd += ' --become'
if check:
cmd += ' --check'
cmd += ' %s'
args += [host]
with tempfile.TemporaryDirectory() as d:
args[0] = d
out = local.run_expect([0, 2], cmd, *args)
files = os.listdir(d)
if not files and 'skipped' in out.stdout.lower():
return {'failed': True, 'skipped': True,
'msg': 'Skipped. You might want to try check=False'}
elif not files:
raise RuntimeError('Error while running {}: {}'.format(
' '.join(cmd), out))
with open(os.path.join(d, files[0]), 'r') as f:
return json.load(f)

@classmethod
def get_runner(cls, inventory):
Expand All @@ -69,99 +196,3 @@ def get_runner(cls, inventory):
except KeyError:
cls._runners[inventory] = cls(inventory)
return cls._runners[inventory]


class Callback(ansible.plugins.callback.CallbackBase):

def __init__(self, *args, **kwargs):
self.result = {}
super(Callback, self).__init__(*args, **kwargs)

def runner_on_ok(self, host, result):
self.result = result

def runner_on_failed(self, host, result, ignore_errors=False):
self.result = result

# pylint: disable=no-self-use
def runner_on_unreachable(self, host, result):
raise RuntimeError(
'Host {} is unreachable: {}'.format(
host, pprint.pformat(result)),
)

def runner_on_skipped(self, host, item=None):
self.result = {
'failed': True,
'msg': 'Skipped. You might want to try check=False',
'item': item,
}


class AnsibleRunner(AnsibleRunnerBase):

def __init__(self, host_list=None):
super(AnsibleRunner, self).__init__(host_list)
self.cli = ansible.cli.playbook.PlaybookCLI(None)
self.cli.options = self.cli.base_parser(
connect_opts=True,
meta_opts=True,
runas_opts=True,
subset_opts=True,
check_opts=True,
inventory_opts=True,
runtask_opts=True,
vault_opts=True,
fork_opts=True,
module_opts=True,
).parse_args([])[0]
self.cli.normalize_become_options()
self.cli.options.connection = "smart"
self.cli.options.inventory = host_list
# pylint: disable=protected-access
self.loader, self.inventory, self.variable_manager = (
self.cli._play_prereqs(self.cli.options))

def get_hosts(self, pattern=None):
return [
e.name for e in
self.inventory.get_hosts(pattern=pattern or "all")
]

def get_variables(self, host):
host = self.inventory.get_host(host)
return self.variable_manager.get_vars(host=host)

def run(self, host, module_name, module_args=None, **kwargs):
self.cli.options.check = kwargs.get("check", False)
self.cli.options.become = kwargs.get("become", False)
action = {"module": module_name}
if module_args is not None:
if module_name in ("command", "shell"):
# Workaround https://github.com/ansible/ansible/issues/13862
module_args = module_args.replace("=", "\\=")
action["args"] = module_args
play = ansible.playbook.play.Play().load({
"hosts": host,
"gather_facts": "no",
"tasks": [{
"action": action,
}],
}, variable_manager=self.variable_manager, loader=self.loader)
tqm = None
callback = Callback()
try:
tqm = ansible.executor.task_queue_manager.TaskQueueManager(
inventory=self.inventory,
variable_manager=self.variable_manager,
loader=self.loader,
options=self.cli.options,
passwords=None,
stdout_callback=callback,
)
tqm.run(play)
finally:
if tqm is not None:
tqm.cleanup()

return callback.result

0 comments on commit 4c27a1e

Please sign in to comment.