diff --git a/simpleflow/activity.py b/simpleflow/activity.py index ef1619fc7..b2858431b 100644 --- a/simpleflow/activity.py +++ b/simpleflow/activity.py @@ -1,3 +1,7 @@ +import functools +import inspect +import subprocess + from . import task @@ -30,6 +34,76 @@ def wrap(func): return wrap +class RequiredArgument(object): + pass + + +def format_arguments(*args, **kwargs): + """ + Examples: + + >>> format_arguments('a', 'b', c=1, d=2) + ['--c=1', '--d=2', 'a', 'b'] + + """ + return ['--{}="{}"'.format(key, value) for key, value in + kwargs.iteritems()] + map(str, args) + + +def zip_arguments_defaults(argspec): + if not argspec.defaults: + return [] + + return zip( + argspec.args[-len(argspec.defaults):], + argspec.defaults) + + +def check_arguments(argspec, args): + # func() or func(**kwargs) or func(a=1, b=2) + if not argspec.varargs and not argspec.args and args: + raise TypeError('command does not take varargs') + + # Calling func(a, b) with func(1, 2, 3) + if (not argspec.varargs and argspec.args and + len(args) != len(argspec.args)): + raise TypeError('command takes {} arguments: {} passed'.format( + len(argspec.args), + len(args))) + + +def check_keyword_arguments(argspec, kwargs): + # func() or func(*args) or func(a, b) + if not argspec.keywords and not argspec.defaults and kwargs: + raise TypeError('command does not take keyword arguments') + + arguments_defaults = zip_arguments_defaults(argspec) + not_found = (set(name for name, value in arguments_defaults if + value is RequiredArgument) - + set(kwargs)) + # Calling func(a=1, b) with func(2) instead of func(a=0, 2) + if not_found: + raise TypeError('argument{} "{}" not found'.format( + 's' if len(not_found) > 1 else '', + ', '.join(not_found))) + + +def execute_program(path=None, argument_format=format_arguments): + def wrap_callable(func): + @functools.wraps(func) + def execute(*args, **kwargs): + check_arguments(argspec, args) + check_keyword_arguments(argspec, kwargs) + + command = path or func.func_name + return subprocess.check_output( + [command] + argument_format(*args, **kwargs)) + + argspec = inspect.getargspec(func) + return execute + return wrap_callable + + class Activity(object): def __init__(self, callable, name=None, diff --git a/tests/test_activity.py b/tests/test_activity.py new file mode 100644 index 000000000..2ed6748e1 --- /dev/null +++ b/tests/test_activity.py @@ -0,0 +1,97 @@ +import tempfile +import os.path + +import pytest + +from simpleflow import activity + + +@activity.execute_program(path='ls') +def ls_nokwargs(*args): + """ + Only accepts a variable number of positional arguments. + + """ + pass + + +def test_execute_program_no_kwargs(): + with tempfile.NamedTemporaryFile() as f: + with pytest.raises(TypeError) as exc_info: + ls_nokwargs(hide=f.name) + + assert (exc_info.value.message == + 'command does not take keyword arguments') + + +@activity.execute_program(path='ls') +def ls_noargs(**kwargs): + """ + Only accepts a variable number of keyword arguments. + + """ + pass + + +def test_execute_program_no_args(): + with tempfile.NamedTemporaryFile() as f: + with pytest.raises(TypeError) as exc_info: + ls_noargs(f.name) + + assert (exc_info.value.message == + 'command does not take varargs') + + +@activity.execute_program(path='ls') +def ls_restrict_named_arguments(hide=activity.RequiredArgument, *args): + pass + + +def test_execute_program_restrict_named_arguments(): + with tempfile.NamedTemporaryFile() as f: + with pytest.raises(TypeError) as exc_info: + ls_restrict_named_arguments(f.name) + + assert (exc_info.value.message == + 'argument "hide" not found') + + +@activity.execute_program(path='ls') +def ls_optional_named_arguments(hide='', *args): + pass + + +def test_execute_program_optional_named_arguments(): + with tempfile.NamedTemporaryFile() as f: + assert ls_optional_named_arguments(f.name).strip() == f.name + assert f.name not in ls_optional_named_arguments(hide=f.name) + + +@activity.execute_program() +def ls(*args, **kwargs): + pass + + +def test_execute_program_with_positional_arguments(): + with tempfile.NamedTemporaryFile() as f: + assert ls(f.name).strip() == f.name + + +def test_execute_program_with_named_arguments(): + with tempfile.NamedTemporaryFile() as f: + assert f.name not in (ls( + os.path.dirname(f.name), + hide=f.name).strip()) + + +@activity.execute_program() +def ls_2args(a, b): + pass + + +def test_ls_2args(): + with pytest.raises(TypeError) as exc_info: + ls_2args(1, 2, 3) + + assert (exc_info.value.message == + 'command takes 2 arguments: 3 passed')