Skip to content

Commit

Permalink
Partial in any parameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
MImmesberger committed Dec 8, 2024
1 parent 72cc11c commit a52f8e7
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions src/_gettsim/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
import dags
import networkx
import pandas as pd
from optree import tree_flatten, tree_flatten_with_path, tree_map, tree_map_with_path
from optree import (
tree_flatten,
tree_flatten_with_path,
tree_map,
tree_map_with_path,
tree_unflatten,
)

from _gettsim.config import (
DEFAULT_TARGETS,
Expand Down Expand Up @@ -841,8 +847,9 @@ def _round_and_partial_parameters_to_functions(
# Partial parameters to functions such that they disappear in the DAG.
# Note: Needs to be done after rounding such that dags recognizes partialled
# parameters.
processed_functions = {}
for name, function in functions.items():
function_leafs, tree_spec = tree_flatten(functions)
processed_functions = []
for function in function_leafs:
arguments = get_names_of_arguments_without_defaults(function)
partial_params = {
i: params[i[:-7]]
Expand All @@ -857,11 +864,11 @@ def _round_and_partial_parameters_to_functions(
if hasattr(function, "__info__"):
partial_func.__info__ = function.__info__

processed_functions[name] = partial_func
processed_functions.append(partial_func)
else:
processed_functions[name] = function
processed_functions.append(function)

return processed_functions
return tree_unflatten(tree_spec, processed_functions)


def _add_rounding_to_function(
Expand Down

0 comments on commit a52f8e7

Please sign in to comment.