diff --git a/src/_gettsim/interface.py b/src/_gettsim/interface.py index 37e56634c..d9da8188f 100644 --- a/src/_gettsim/interface.py +++ b/src/_gettsim/interface.py @@ -8,7 +8,13 @@ import dags import networkx import pandas as pd -from optree import tree_flatten, tree_flatten_with_path, tree_map, tree_map_with_path +from optree import ( + tree_flatten, + tree_flatten_with_path, + tree_map, + tree_map_with_path, + tree_unflatten, +) from _gettsim.config import ( DEFAULT_TARGETS, @@ -841,8 +847,9 @@ def _round_and_partial_parameters_to_functions( # Partial parameters to functions such that they disappear in the DAG. # Note: Needs to be done after rounding such that dags recognizes partialled # parameters. - processed_functions = {} - for name, function in functions.items(): + function_leafs, tree_spec = tree_flatten(functions) + processed_functions = [] + for function in function_leafs: arguments = get_names_of_arguments_without_defaults(function) partial_params = { i: params[i[:-7]] @@ -857,11 +864,11 @@ def _round_and_partial_parameters_to_functions( if hasattr(function, "__info__"): partial_func.__info__ = function.__info__ - processed_functions[name] = partial_func + processed_functions.append(partial_func) else: - processed_functions[name] = function + processed_functions.append(function) - return processed_functions + return tree_unflatten(tree_spec, processed_functions) def _add_rounding_to_function(