Skip to content

Commit

Permalink
ProcessFunction: Add support for variadic arguments
Browse files Browse the repository at this point in the history
Up till now, variadic arguments, i.e., arguments defined as `*args` in a
function signature which collects any remaining positional arguments,
were not supported for process functions. The main reason was that it
wasn't immediately clear what the link label should be for these inputs.

For normal positional arguments we can take the name of the argument
declaration in the function signature, and for keyword arguments we take
the keyword with which the argument is passed in the function invocation.
But for variadic arguments there is no specific argument name, not in
the function signature, nor in the function invocation. However, we can
simply create a link label. We just have to ensure that it doesn't clash
with the link labels that will be generated for the positional and
keyword arguments.

Here the link label will be determined with the following format:

    `{label_prefix}_{index}`

The `label_prefix` is determined by the name of the variadic argument. If a
function is declared as `function(a, *args, **kwargs)` the prefix will
be equal to `args` and if it is `function(*some_var_args)` it will be
`some_var_args`. The index will simply be the zero-based index of the
argument within the variadic arguments tuple. This would therefore give
link labels `args_0`, `args_1` etc. in the first example.

A clash of labels is possible, for example with the function definition:

    def function(args_0, *args):

which, when invoked as:

    function(1, *(2, 3))

would generate the label `args_0` for the first positional argument,
but also `args_0` for the first variadic argument. This clash is
detected and a `RuntimeError` is raised instructing the user to fix it.
  • Loading branch information
sphuber committed Dec 14, 2022
1 parent 6608e0d commit 36118c8
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 21 deletions.
52 changes: 43 additions & 9 deletions aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Class and decorators to generate processes out of simple python functions."""
from __future__ import annotations

import collections
import functools
import inspect
Expand Down Expand Up @@ -192,6 +194,7 @@ class FunctionProcess(Process):
"""Function process class used for turning functions into a Process"""

_func_args: Sequence[str] = ()
_varargs: str | None = None

@staticmethod
def _func(*_args, **_kwargs) -> dict:
Expand Down Expand Up @@ -224,9 +227,6 @@ def build(func: Callable[..., Any], node_class: Type['ProcessNode']) -> Type['Fu
ndefaults = len(defaults) if defaults else 0
first_default_pos = nargs - ndefaults

if varargs is not None:
raise ValueError('variadic arguments are not supported')

def _define(cls, spec): # pylint: disable=unused-argument
"""Define the spec dynamically"""
from plumpy.ports import UNSPECIFIED
Expand Down Expand Up @@ -271,8 +271,8 @@ def _define(cls, spec): # pylint: disable=unused-argument
if not port_label.has_default():
port_label.default = func.__name__

# If the function support kwargs then allow dynamic inputs, otherwise disallow
spec.inputs.dynamic = keywords is not None
# If the function supports varargs or kwargs then allow dynamic inputs, otherwise disallow
spec.inputs.dynamic = keywords is not None or varargs

# Function processes must have a dynamic output namespace since we do not know beforehand what outputs
# will be returned and the valid types for the value should be `Data` nodes as well as a dictionary because
Expand All @@ -286,6 +286,7 @@ def _define(cls, spec): # pylint: disable=unused-argument
'_func': staticmethod(func),
Process.define.__name__: classmethod(_define),
'_func_args': args,
'_varargs': varargs or None,
'_node_class': node_class
}
)
Expand All @@ -299,13 +300,15 @@ def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable=
"""
nargs = len(args)
nparameters = len(cls._func_args)
has_varargs = cls._varargs is not None

# If the spec is dynamic, i.e. the function signature includes `**kwargs` and the number of positional arguments
# passed is larger than the number of explicitly defined parameters in the signature, the inputs are invalid and
# we should raise. If we don't, some of the passed arguments, intended to be positional arguments, will be
# misinterpreted as keyword arguments, but they won't have an explicit name to use for the link label, causing
# the input link to be completely lost.
if cls.spec().inputs.dynamic and nargs > nparameters:
# the input link to be completely lost. If the function supports variadic arguments, however, additional args
# should be accepted.
if cls.spec().inputs.dynamic and nargs > nparameters and not has_varargs:
name = cls._func.__name__
raise TypeError(f'{name}() takes {nparameters} positional arguments but {nargs} were given')

Expand All @@ -331,7 +334,35 @@ def args_to_dict(cls, *args: Any) -> Dict[str, Any]:
:return: A label -> value dictionary
"""
return dict(list(zip(cls._func_args, args)))
dictionary = {}
values = list(args)

for arg in cls._func_args:
try:
dictionary[arg] = values.pop(0)
except IndexError:
pass

# If arguments remain and the function supports variadic arguments, add those as well.
if cls._varargs and args:

# By default the prefix for variadic labels is the key with which the varargs were declared
variadic_prefix = cls._varargs

for index, arg in enumerate(values):
label = f'{variadic_prefix}_{index}'

# If the generated vararg label overlaps with a keyword argument, function signature should be changed
if label in dictionary:
raise RuntimeError(
f'variadic argument with index `{index}` would get the label `{label}` but this is already in '
'use by another function argument with the exact same name. To avoid this error, please change '
f'the name of argument `{label}` to something else.'
)

dictionary[label] = arg

return dictionary

@classmethod
def get_or_create_db_record(cls) -> 'ProcessNode':
Expand Down Expand Up @@ -400,7 +431,10 @@ def run(self) -> Optional['ExitCode']:
try:
args[self._func_args.index(name)] = value
except ValueError:
kwargs[name] = value
if name.startswith(f'{self._varargs}_'):
args.append(value)
else:
kwargs[name] = value

result = self._func(*args, **kwargs)

Expand Down
24 changes: 19 additions & 5 deletions docs/source/topics/processes/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,8 @@ Both function calls in the example above will have the exact same result.

Variable and keyword arguments
==============================
Variable arguments are *not* supported by process functions.
The reasoning behind this is that the process specification for the :py:class:`~aiida.engine.processes.functions.FunctionProcess` is built dynamically based on the function signature and so the names of the inputs are based on the parameter name from the function definition, or the named argument when the function is called.
Since for variable arguments, neither at function definition nor at function call, explicit parameter names are used, the engine can impossibly determine what names, and by extensions link label, to use for the inputs.

In contrast, keyword arguments for that reason *are* supported and it is the keyword used when the function is called that determines the names of the parameters and the labels of the input links.
The following snippet is therefore perfectly legal and will return the sum of all the nodes that are passed:
Keyword arguments can be used effectively if a process function should take a number of arguments that is unknown beforehand:

.. include:: include/snippets/functions/signature_calcfunction_kwargs.py
:code: python
Expand All @@ -103,6 +99,24 @@ Note that the inputs **have to be passed as keyword arguments** because they are
If the inputs had simply been passed as positional arguments, the engine could not possibly have determined what label to use for the links that connect the input nodes with the calculation function node.
For this reason, invoking a 'dynamic' function, i.e. one that supports ``**kwargs`` in its signature, with more positional arguments than are explicitly named in the signature, will raise a ``TypeError``.

.. versionadded:: 2.3

Variable arguments are now supported.

Variable arguments can be used in case the function should accept a list of inputs of unknown length.
Consider the example of a calculation function that computes the average of a number of ``Int`` nodes:

.. include:: include/snippets/functions/signature_calcfunction_args.py
:code: python

The result will be a ``Float`` node with value ``2``.
Since in this example the arguments are not explicitly declared in the function signature, nor are their values passed with a keyword in the function invocation, AiiDA needs to come up with a different way to determine the labels to link the input nodes to the calculation.
For variadic arguments, link labels are created from the variable argument declaration (``*args`` in the example), followed by an index.
The link labels for the example above will therefore be ``args_0``, ``args_1`` and ``args_2``.
If any of these labels were to overlap with the label of a positional or keyword argument, a ``RuntimeError`` will be raised.
In this case, the conflicting argument name needs to be changed to something that does not overlap with the automatically generated labels for the variadic arguments.


Return values
=============
In :numref:`fig_calculation_functions_kwargs` you can see that the engine used the label ``result`` for the link connecting the calculation function node with its output node.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
from aiida.engine import calcfunction
from aiida.orm import Int


@calcfunction
def average(*args):
    """Return the arithmetic mean of all positional arguments that are passed."""
    total = sum(args)
    count = len(args)
    return total / count

result = average(*(Int(1), Int(2), Int(3)))
39 changes: 32 additions & 7 deletions tests/engine/test_process_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ def function_return_input(data):
return data


@calcfunction
def function_variadic_arguments(int_a, int_b, *args):
    """Return the sum of the two declared arguments and of all variadic arguments."""
    variadic_total = orm.Int(sum(args))
    return int_a + int_b + variadic_total


@calcfunction
def function_variadic_arguments_label_overlap(args_0, *args):
    """Return the sum of ``args_0`` and of all variadic arguments.

    The declared parameter is deliberately named ``args_0``, the same name as the variadic parameter ``args`` followed
    by ``_0``.
    """
    variadic_total = orm.Int(sum(args))
    return args_0 + variadic_total


@calcfunction
def function_return_true():
return get_true_node()
Expand All @@ -58,8 +68,8 @@ def function_args_with_default(data_a=lambda: orm.Int(DEFAULT_INT)):
@calcfunction
def function_with_none_default(int_a, int_b, int_c=None):
if int_c is not None:
return orm.Int(int_a + int_b + int_c)
return orm.Int(int_a + int_b)
return int_a + int_b + int_c
return int_a + int_b


@workfunction
Expand Down Expand Up @@ -203,12 +213,27 @@ def test_get_function_source_code():


def test_function_varargs():
"""Variadic arguments are not supported and should raise."""
with pytest.raises(ValueError):
"""Test a function with variadic arguments."""
result, node = function_variadic_arguments.run_get_node(orm.Int(1), orm.Int(2), *(orm.Int(3), orm.Int(4)))
assert isinstance(result, orm.Int)
assert result.value == 10

inputs = node.get_incoming().nested()
assert inputs['int_a'].value == 1
assert inputs['int_b'].value == 2
assert inputs['args_0'].value == 3
assert inputs['args_1'].value == 4
assert node.inputs.args_0.value == 3
assert node.inputs.args_1.value == 4

@workfunction
def function_varargs(*args): # pylint: disable=unused-variable
return args

def test_function_varargs_label_overlap():
    """Test a function with variadic arguments where the automatic label overlaps with a declared argument.

    This should raise a ``RuntimeError``.
    """
    # The first positional value binds to the declared parameter; the unpacked tuple supplies the variadic arguments.
    # Only the start of the error message is matched, with wildcards for the reported index and label.
    with pytest.raises(RuntimeError, match=r'variadic argument with index `.*` would get the label `.*` but this'):
        function_variadic_arguments_label_overlap.run_get_node(orm.Int(1), *(orm.Int(2), orm.Int(3)))


def test_function_args():
Expand Down

0 comments on commit 36118c8

Please sign in to comment.