Skip to content

Commit

Permalink
Fix bug that led to no derived functions for function arguments and a…
Browse files Browse the repository at this point in the history
…dd test.
  • Loading branch information
MImmesberger committed Dec 7, 2024
1 parent 480be80 commit 9c91483
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
1 change: 0 additions & 1 deletion src/_gettsim/policy_environment_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def add_derived_functions_to_functions_tree(
"""
names_of_columns_in_data = tree_flatten_with_qualified_name(data)[0]

# Create derived functions
(
time_conversion_functions,
Expand Down
11 changes: 9 additions & 2 deletions src/_gettsim/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from optree import tree_flatten_with_path

from _gettsim.config import SUPPORTED_GROUPINGS
from _gettsim.functions.policy_function import PolicyFunction


class KeyErrorMessage(str):
Expand Down Expand Up @@ -274,7 +275,7 @@ def format_errors_and_warnings(text, width=79):
return formatted_text


def get_names_of_arguments_without_defaults(function):
def get_names_of_arguments_without_defaults(function: PolicyFunction) -> list[str]:
"""Get argument names without defaults.
The detection of argument names also works for partialed functions.
Expand All @@ -296,7 +297,13 @@ def get_names_of_arguments_without_defaults(function):
p for p in parameters if parameters[p].default == parameters[p].empty
]

return argument_names_without_defaults
# Add namespace to argument names if not already present
qualified_argument_names_without_defaults = [
f"{function.module_name}__{p}" if "__" not in p else p
for p in argument_names_without_defaults
]

return qualified_argument_names_without_defaults


def remove_group_suffix(col):
Expand Down
38 changes: 38 additions & 0 deletions src/_gettsim_tests/test_shared.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,51 @@
import pytest

from _gettsim.functions.policy_function import PolicyFunction
from _gettsim.shared import (
create_dict_from_list,
get_names_of_arguments_without_defaults,
merge_nested_dicts,
tree_to_dict_with_qualified_name,
tree_update,
)


def module1__module2__function_with_local_function_argument(a):
"""Function with a local function argument."""
return a


def module1__module2__function_with_global_function_argument(module1__module3__a):
"""Function with a global function argument."""
return module1__module3__a


@pytest.mark.parametrize(
"function, expected_argument_name",
[
(
PolicyFunction(
module1__module2__function_with_local_function_argument,
module_name="module1__module2",
function_name="function_with_local_function_argument",
),
"module1__module2__a",
),
(
PolicyFunction(
module1__module2__function_with_global_function_argument,
module_name="module1__module2",
function_name="function_with_global_function_argument",
),
"module1__module3__a",
),
],
)
def test_get_names_of_arguments(function, expected_argument_name):
names = get_names_of_arguments_without_defaults(function)
assert names == [expected_argument_name]


@pytest.mark.parametrize(
"tree, path, value, expected",
[
Expand Down

0 comments on commit 9c91483

Please sign in to comment.