diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index 698c4df5a3..a227abf948 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -8,6 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Class and decorators to generate processes out of simple python functions.""" +from __future__ import annotations + import collections import functools import inspect @@ -192,6 +194,7 @@ class FunctionProcess(Process): """Function process class used for turning functions into a Process""" _func_args: Sequence[str] = () + _varargs: str | None = None @staticmethod def _func(*_args, **_kwargs) -> dict: @@ -224,9 +227,6 @@ def build(func: Callable[..., Any], node_class: Type['ProcessNode']) -> Type['Fu ndefaults = len(defaults) if defaults else 0 first_default_pos = nargs - ndefaults - if varargs is not None: - raise ValueError('variadic arguments are not supported') - def _define(cls, spec): # pylint: disable=unused-argument """Define the spec dynamically""" from plumpy.ports import UNSPECIFIED @@ -271,8 +271,8 @@ def _define(cls, spec): # pylint: disable=unused-argument if not port_label.has_default(): port_label.default = func.__name__ - # If the function support kwargs then allow dynamic inputs, otherwise disallow - spec.inputs.dynamic = keywords is not None + # If the function supports varargs or kwargs then allow dynamic inputs, otherwise disallow + spec.inputs.dynamic = keywords is not None or varargs # Function processes must have a dynamic output namespace since we do not know beforehand what outputs # will be returned and the valid types for the value should be `Data` nodes as well as a dictionary because @@ -286,6 +286,7 @@ def _define(cls, spec): # pylint: disable=unused-argument '_func': staticmethod(func), Process.define.__name__: classmethod(_define), '_func_args': args, + '_varargs': varargs or None, '_node_class': node_class } ) @@ -299,13 +300,15 @@ def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable= """ nargs = len(args) nparameters = len(cls._func_args) + has_varargs = cls._varargs is not None # If the spec is dynamic, i.e. the function signature includes `**kwargs` and the number of positional arguments # passed is larger than the number of explicitly defined parameters in the signature, the inputs are invalid and # we should raise. If we don't, some of the passed arguments, intended to be positional arguments, will be # misinterpreted as keyword arguments, but they won't have an explicit name to use for the link label, causing - # the input link to be completely lost. - if cls.spec().inputs.dynamic and nargs > nparameters: + # the input link to be completely lost. If the function supports variadic arguments, however, additional args + # should be accepted. + if cls.spec().inputs.dynamic and nargs > nparameters and not has_varargs: name = cls._func.__name__ raise TypeError(f'{name}() takes {nparameters} positional arguments but {nargs} were given') @@ -331,7 +334,35 @@ def args_to_dict(cls, *args: Any) -> Dict[str, Any]: :return: A label -> value dictionary """ - return dict(list(zip(cls._func_args, args))) + dictionary = {} + values = list(args) + + for arg in cls._func_args: + try: + dictionary[arg] = values.pop(0) + except IndexError: + pass + + # If arguments remain and the function supports variadic arguments, add those as well. + if cls._varargs and args: + + # By default the prefix for variadic labels is the key with which the varargs were declared + variadic_prefix = cls._varargs + + for index, arg in enumerate(values): + label = f'{variadic_prefix}_{index}' + + # If the generated vararg label overlaps with a keyword argument, function signature should be changed + if label in dictionary: + raise RuntimeError( + f'variadic argument with index `{index}` would get the label `{label}` but this is already in ' + 'use by another function argument with the exact same name. To avoid this error, please change ' + f'the name of argument `{label}` to something else.' + ) + + dictionary[label] = arg + + return dictionary @classmethod def get_or_create_db_record(cls) -> 'ProcessNode': @@ -400,7 +431,10 @@ def run(self) -> Optional['ExitCode']: try: args[self._func_args.index(name)] = value except ValueError: - kwargs[name] = value + if name.startswith(f'{self._varargs}_'): + args.append(value) + else: + kwargs[name] = value result = self._func(*args, **kwargs) diff --git a/docs/source/topics/processes/functions.rst b/docs/source/topics/processes/functions.rst index deda0601b9..25dfeb80d2 100644 --- a/docs/source/topics/processes/functions.rst +++ b/docs/source/topics/processes/functions.rst @@ -82,12 +82,8 @@ Both function calls in the example above will have the exact same result. Variable and keyword arguments ============================== -Variable arguments are *not* supported by process functions. -The reasoning behind this is that the process specification for the :py:class:`~aiida.engine.processes.functions.FunctionProcess` is built dynamically based on the function signature and so the names of the inputs are based on the parameter name from the function definition, or the named argument when the function is called. -Since for variable arguments, neither at function definition nor at function call, explicit parameter names are used, the engine can impossibly determine what names, and by extensions link label, to use for the inputs. -In contrast, keyword arguments for that reason *are* supported and it is the keyword used when the function is called that determines the names of the parameters and the labels of the input links. -The following snippet is therefore perfectly legal and will return the sum of all the nodes that are passed: +Keyword arguments can be used effectively if a process function should take a number of arguments that is unknown beforehand: .. include:: include/snippets/functions/signature_calcfunction_kwargs.py :code: python @@ -103,6 +99,24 @@ Note that the inputs **have to be passed as keyword arguments** because they are If the inputs would simply have been passed as positional arguments, the engine could have impossibly determined what label to use for the links that connect the input nodes with the calculation function node. For this reason, invoking a 'dynamic' function, i.e. one that supports ``**kwargs`` in its signature, with more positional arguments that explicitly named in the signature, will raise a ``TypeError``. +.. versionadded:: 2.3 + + Variable arguments are now supported. + +Variable arguments can be used in case the function should accept a list of inputs of unknown length. +Consider the example of a calculation function that computes the average of a number of ``Int`` nodes: + +.. include:: include/snippets/functions/signature_calcfunction_args.py + :code: python + +The result will be a ``Float`` node with value ``2``. +Since in this example the arguments are not explicitly declared in the function signature, nor are their values passed with a keyword in the function invocation, AiiDA needs to come up with a different way to determine the labels to link the input nodes to the calculation. +For variadic arguments, link labels are created from the variable argument declaration (``*args`` in the example), followed by an index. +The link labels for the example above will therefore be ``args_0``, ``args_1`` and ``args_2``. +If any of these labels were to overlap with the label of a positional or keyword argument, a ``RuntimeError`` will be raised. +In this case, the conflicting argument name needs to be changed to something that does not overlap with the automatically generated labels for the variadic arguments. + + Return values ============= In :numref:`fig_calculation_functions_kwargs` you can see that the engine used the label ``result`` for the link connecting the calculation function node with its output node. diff --git a/docs/source/topics/processes/include/snippets/functions/signature_calcfunction_args.py b/docs/source/topics/processes/include/snippets/functions/signature_calcfunction_args.py new file mode 100644 index 0000000000..4d3536797c --- /dev/null +++ b/docs/source/topics/processes/include/snippets/functions/signature_calcfunction_args.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +from aiida.engine import calcfunction +from aiida.orm import Int + + +@calcfunction +def average(*args): + return sum(args) / len(args) + +result = average(*(Int(1), Int(2), Int(3))) diff --git a/tests/engine/test_process_function.py b/tests/engine/test_process_function.py index 273bc88d53..9bf6d90fc3 100644 --- a/tests/engine/test_process_function.py +++ b/tests/engine/test_process_function.py @@ -40,6 +40,16 @@ def function_return_input(data): return data +@calcfunction +def function_variadic_arguments(int_a, int_b, *args): + return int_a + int_b + orm.Int(sum(args)) + + +@calcfunction +def function_variadic_arguments_label_overlap(args_0, *args): + return args_0 + orm.Int(sum(args)) + + @calcfunction def function_return_true(): return get_true_node() @@ -58,8 +68,8 @@ def function_args_with_default(data_a=lambda: orm.Int(DEFAULT_INT)): @calcfunction def function_with_none_default(int_a, int_b, int_c=None): if int_c is not None: - return orm.Int(int_a + int_b + int_c) - return orm.Int(int_a + int_b) + return int_a + int_b + int_c + return int_a + int_b @workfunction @@ -203,12 +213,27 @@ def test_get_function_source_code(): def test_function_varargs(): - """Variadic arguments are not supported and should raise.""" - with pytest.raises(ValueError): + """Test a function with variadic arguments.""" + result, node = function_variadic_arguments.run_get_node(orm.Int(1), orm.Int(2), *(orm.Int(3), orm.Int(4))) + assert isinstance(result, orm.Int) + assert result.value == 10 + + inputs = node.get_incoming().nested() + assert inputs['int_a'].value == 1 + assert inputs['int_b'].value == 2 + assert inputs['args_0'].value == 3 + assert inputs['args_1'].value == 4 + assert node.inputs.args_0.value == 3 + assert node.inputs.args_1.value == 4 - @workfunction - def function_varargs(*args): # pylint: disable=unused-variable - return args + +def test_function_varargs_label_overlap(): + """Test a function with variadic arguments where the automatic label overlaps with a declared argument. + + This should raise a ``RuntimeError``. + """ + with pytest.raises(RuntimeError, match=r'variadic argument with index `.*` would get the label `.*` but this'): + function_variadic_arguments_label_overlap.run_get_node(orm.Int(1), *(orm.Int(2), orm.Int(3))) def test_function_args():