Skip to content

Commit

Permalink
Rewrite _create_input_data.
Browse files Browse the repository at this point in the history
  • Loading branch information
MImmesberger committed Dec 8, 2024
1 parent a52f8e7 commit f34c2d5
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 61 deletions.
141 changes: 83 additions & 58 deletions src/_gettsim/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def compute_taxes_and_transfers( # noqa: PLR0913
targets=targets,
data=data,
)
functions_not_overridden, functions_overridden = _filter_functions_by_name(
all_functions=all_functions,
functions_not_overridden, functions_overridden = _filter_tree_by_name_list(
tree=all_functions,
qualified_names_list=tree_flatten_with_qualified_name(data)[0],
)
data = _convert_data_to_correct_types(data, functions_overridden)
Expand Down Expand Up @@ -129,19 +129,16 @@ def compute_taxes_and_transfers( # noqa: PLR0913
check_minimal_specification=check_minimal_specification,
).nodes
# Select functions that are nodes of the DAG.
_, necessary_functions = _filter_functions_by_name(functions_not_overridden, nodes)
_, necessary_functions = _filter_tree_by_name_list(functions_not_overridden, nodes)
# Round and partial parameters into functions.
processed_functions = _round_and_partial_parameters_to_functions(
necessary_functions, params, rounding
)

# Create input data.
input_data = _create_input_data(
data=data,
processed_functions=processed_functions,
# Input structure for final DAG.
input_structure = dags.dag_tree.create_input_structure_tree(
functions=processed_functions,
targets=targets,
columns_overriding_functions=columns_overriding_functions,
check_minimal_specification=check_minimal_specification,
)

# Calculate results.
Expand All @@ -152,6 +149,16 @@ def compute_taxes_and_transfers( # noqa: PLR0913
name_clashes="raise",
)

# Create input data.
input_data = _create_input_data(
data=data,
processed_functions=processed_functions,
targets=targets,
names_of_columns_overriding_functions=names_of_columns_overriding_functions,
input_structure=input_structure,
check_minimal_specification=check_minimal_specification,
)

results = tax_transfer_function(**input_data)

# Prepare results.
Expand Down Expand Up @@ -494,8 +501,8 @@ def _convert_data_to_correct_types(
return data


def _filter_functions_by_name(
functions_tree: NestedFunctionDict,
def _filter_tree_by_name_list(
tree: NestedFunctionDict | NestedDataDict,
qualified_names_list: list[str],
) -> tuple[NestedFunctionDict, NestedFunctionDict]:
"""Filter functions by name.
Expand All @@ -506,48 +513,49 @@ def _filter_functions_by_name(
Parameters
----------
functions_tree : NestedFunctionDict
tree : NestedFunctionDict | NestedDataDict
Dictionary containing functions to build the DAG.
qualified_names_list : list[str]
List of qualified names.
Returns
-------
functions_not_in_names_list : NestedFunctionDict
not_in_names_list : NestedFunctionDict
All functions except the ones that are overridden by an input column.
functions_in_names_list : NestedFunctionDict
in_names_list : NestedFunctionDict
Functions that are overridden by an input column.
"""
functions_not_in_names_list = {}
functions_in_names_list = {}
not_in_names_list = {}
in_names_list = {}

function_paths, functions_leafs, _ = tree_flatten_with_path(functions_tree)
paths, leafs, _ = tree_flatten_with_path(tree)

for name, func in zip(function_paths, functions_leafs):
for name, leaf in zip(paths, leafs):
qualified_name = "__".join(name)
if qualified_name in qualified_names_list:
functions_in_names_list = tree_update(
functions_in_names_list,
in_names_list = tree_update(
in_names_list,
name,
func,
leaf,
)
else:
functions_not_in_names_list = tree_update(
functions_not_in_names_list,
not_in_names_list = tree_update(
not_in_names_list,
name,
func,
leaf,
)

return functions_not_in_names_list, functions_in_names_list
return not_in_names_list, in_names_list


def _create_input_data(
data,
processed_functions,
targets,
columns_overriding_functions,
check_minimal_specification="ignore",
def _create_input_data( # noqa: PLR0913
data: NestedDataDict,
processed_functions: NestedFunctionDict,
targets: NestedTargetDict,
names_of_columns_overriding_functions: list[str],
input_structure: NestedInputStructureDict,
check_minimal_specification: Literal["ignore", "warn", "raise"] = "ignore",
):
"""Create input data for use in the calculation of taxes and transfers by:
Expand All @@ -556,43 +564,53 @@ def _create_input_data(
Parameters
----------
data : Dict of pandas.Series
data : NestedDataDict
Data provided by the user.
processed_functions : dict of callable
processed_functions : NestedFunctionDict
Dictionary mapping function names to callables.
targets : list of str
List of strings with names of functions whose output is actually needed by the
user.
columns_overriding_functions : str list of str
Names of columns in the data which are preferred over function defined in the
tax and transfer system.
targets : NestedTargetDict
Targets provided by the user.
names_of_columns_overriding_functions : list[str]
Names of columns in the data that override hard-coded functions.
input_structure : NestedInputStructureDict
Tree representing the input structure.
check_minimal_specification : {"ignore", "warn", "raise"}, default "ignore"
Indicator for whether checks which ensure the most minimal configuration should
be silenced, emitted as warnings or errors.
Returns
-------
input_data : Dict of numpy.array
input_data : NestedDataDict
Data which can be used to calculate taxes and transfers.
"""
# Create dag using processed functions
dag = set_up_dag(
all_functions=processed_functions,
targets=targets,
columns_overriding_functions=columns_overriding_functions,
names_of_columns_overriding_functions=names_of_columns_overriding_functions,
input_structure=input_structure,
check_minimal_specification=check_minimal_specification,
)
root_nodes = {n for n in dag.nodes if list(dag.predecessors(n)) == []}
_fail_if_root_nodes_are_missing(root_nodes, data, processed_functions)
data = _reduce_to_necessary_data(root_nodes, data, check_minimal_specification)
data_cols = tree_flatten_with_qualified_name(data)[0]
_fail_if_root_nodes_are_missing(
functions=processed_functions,
root_nodes=root_nodes,
data_cols=data_cols,
)

# Convert series to numpy arrays
data = {key: series.values for key, series in data.items()}
# Check that only necessary data is passed
unnecessary_data, input_data = _filter_tree_by_name_list(
tree=data,
qualified_names_list=root_nodes,
)
names_unnecessary_data = tree_flatten_with_qualified_name(unnecessary_data)[0]
_warn_or_raise_if_unnecessary_data(
names_unnecessary_data, check_minimal_specification
)

# Restrict to root nodes
input_data = {k: v for k, v in data.items() if k in root_nodes}
return input_data
return tree_map(lambda x: x.values, input_data)


class FunctionsAndColumnsOverlapWarning(UserWarning):
Expand Down Expand Up @@ -764,20 +782,27 @@ def _fail_if_foreign_keys_are_invalid(data: NestedDataDict) -> None:
raise ValueError(message)


def _fail_if_root_nodes_are_missing(root_nodes, data, functions):
def _fail_if_root_nodes_are_missing(
functions: NestedFunctionDict,
root_nodes: list[str],
data_cols: list[str],
) -> None:
# Identify functions that are part of the DAG, but do not depend
# on any other function
names_to_functions_dict = tree_to_dict_with_qualified_name(functions)
funcs_based_on_params_only = [
func_name
for func_name, func in functions.items()
for func_name, func in names_to_functions_dict.items()
if len(
[a for a in inspect.signature(func).parameters if not a.endswith("_params")]
)
== 0
]

missing_nodes = [
c for c in root_nodes if c not in data and c not in funcs_based_on_params_only
c
for c in root_nodes
if c not in data_cols and c not in funcs_based_on_params_only
]
if missing_nodes:
formatted = format_list_linewise(missing_nodes)
Expand All @@ -802,18 +827,18 @@ def _fail_if_data_not_dict_with_sequence_leafs_or_dataframe(data: Any) -> None:
)


def _reduce_to_necessary_data(root_nodes, data, check_minimal_specification):
def _warn_or_raise_if_unnecessary_data(
names_unnecessary_data: list[str],
check_minimal_specification: Literal["ignore", "warn", "raise"],
) -> None:
# Produce warning or fail if more than necessary data is given.
unnecessary_data = set(data) - root_nodes
formatted = format_list_linewise(unnecessary_data)
formatted = format_list_linewise(names_unnecessary_data)
message = f"The following columns in 'data' are unused.\n\n{formatted}"
if unnecessary_data and check_minimal_specification == "warn":
if names_unnecessary_data and check_minimal_specification == "warn":
warnings.warn(message, stacklevel=2)
elif unnecessary_data and check_minimal_specification == "raise":
elif names_unnecessary_data and check_minimal_specification == "raise":
raise ValueError(message)

return {k: v for k, v in data.items() if k not in unnecessary_data}


def _round_and_partial_parameters_to_functions(
functions: NestedFunctionDict,
Expand Down
6 changes: 3 additions & 3 deletions src/_gettsim_tests/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
_fail_if_foreign_keys_are_invalid,
_fail_if_group_variables_not_constant_within_groups,
_fail_if_pid_is_non_unique,
_filter_functions_by_name,
_filter_tree_by_name_list,
_round_and_partial_parameters_to_functions,
_use_correct_series_names,
build_data_tree,
Expand Down Expand Up @@ -865,8 +865,8 @@ def test_use_correct_series_names(input_object, expected_output):
),
],
)
def test_filter_functions_by_name(tree, names, expected_names):
result_not_in_names, result_in_names = _filter_functions_by_name(tree, names)
def test_filter_tree_by_name_list(tree, names, expected_names):
result_not_in_names, result_in_names = _filter_tree_by_name_list(tree, names)
flattened_result_not_in_names = tree_flatten_with_qualified_name(
result_not_in_names
)[0]
Expand Down

0 comments on commit f34c2d5

Please sign in to comment.