From cd8611c93c6d6fc5493d2cfa7acc08d34d4158a5 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Mon, 30 May 2022 15:09:05 -0700 Subject: [PATCH] Fixes case where optional user inputs broke computation The execute function gets all upstream nodes of the required node to compute. This will mean that there will likely be "user input" nodes to cycle through. When we were computing the DFS value for them, we would assume they were required. To illustrate, if you had a function that had optional input, `baz` e.g. ```python def foo(bar: int, baz: float = 1.0) -> float: ``` This meant that if you did not pass in a value for `baz`, and `baz` was a user input, Hamilton would complain that a required node was not provided. Even though it was not required for computation. So to fix that, in execute any user node is now marked with `optional`. I believe this is fine to do, because if this is not the case, there will be a node in the graph that will have `baz` as a REQUIRED dependency, and thus things will break appropriately. To help with that, I also fixed and added some unit tests. One unit test is to ensure that we don't remove passing in `None` values as part of the kwargs to the function. Since that's what we do now, and this was another way to fix this bug, which I think would be the wrong solution. Otherwise I added tests to ensure that node order does not change the result too. --- examples/modin/original_my_script.py | 32 +++++++++++ hamilton/graph.py | 6 +- tests/resources/optional_dependencies.py | 17 ++++++ tests/test_graph.py | 73 +++++++++++++++++++++++- 4 files changed, 126 insertions(+), 2 deletions(-) create mode 100644 examples/modin/original_my_script.py diff --git a/examples/modin/original_my_script.py b/examples/modin/original_my_script.py new file mode 100644 index 00000000..85ae5059 --- /dev/null +++ b/examples/modin/original_my_script.py @@ -0,0 +1,32 @@ +import importlib +import logging +import sys + +import pandas as pd +from hamilton import driver + +logging.basicConfig(stream=sys.stdout) +initial_columns = { # load from actuals or wherever -- this is our initial data we use as input. + # Note: these values don't have to be all series, they could be a scalar. + 'signups': pd.Series([1, 10, 50, 100, 200, 400]), + 'spend': pd.Series([10, 10, 20, 40, 40, 50]), +} +# we need to tell hamilton where to load function definitions from +module_name = 'my_functions' +module = importlib.import_module(module_name) +dr = driver.Driver(initial_columns, module) # can pass in multiple modules +# we need to specify what we want in the final dataframe. +output_columns = [ + 'spend', + 'signups', + 'avg_3wk_spend', + 'spend_per_signup', + 'spend_zero_mean_unit_variance' +] +# let's create the dataframe! +df = dr.execute(output_columns) +print(df.to_string()) + +# To visualize do `pip install sf-hamilton[visualization]` if you want these to work +# dr.visualize_execution(output_columns, './my_dag.dot', {}, graphviz_kwargs=dict(graph_attr={'ratio': '1'})) +# dr.display_all_functions('./my_full_dag.dot', graphviz_kwargs=dict(graph_attr={'ratio': '1'})) diff --git a/hamilton/graph.py b/hamilton/graph.py index efb02742..e071b261 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -454,7 +454,11 @@ def dfs_traverse(node: node.Node, dependency_type: DependencyType = DependencyTy computed[node.name] = value for final_var_node in nodes: - dfs_traverse(final_var_node) + dep_type = DependencyType.REQUIRED + if final_var_node.user_defined: + # from the top level, we don't know if this UserInput is required. So mark as optional. + dep_type = DependencyType.OPTIONAL + dfs_traverse(final_var_node, dep_type) return computed @staticmethod diff --git a/tests/resources/optional_dependencies.py b/tests/resources/optional_dependencies.py index 6f0d6018..ebad0273 100644 --- a/tests/resources/optional_dependencies.py +++ b/tests/resources/optional_dependencies.py @@ -22,6 +22,23 @@ def g(e: int, f: int = _F) -> int: return e + f +def i(h: int, f: int = _F) -> int: + """we will pass None to e and so don't want the code to break for the unit test.""" + if h is None: + h = 10 + return h + f + + +def none_result() -> int: + """Function to show that we don't filter out the result.""" + return None + + +def j(none_result: int, f: int = _F) -> int: + # dont use f. + return none_result + + def _do_all(a_val: int = _A, b_val: int = _B, d_val: int = _D, f_val: int = _F) -> Dict[str, Any]: c_val = c(a_val, b_val) e_val = e(c_val, d_val) diff --git a/tests/test_graph.py b/tests/test_graph.py index 542c34b0..8811be10 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -560,12 +560,83 @@ def test_optional_execute(config, inputs, overrides): Be careful adding tests with conflicting values between them. """ fg = graph.FunctionGraph(tests.resources.optional_dependencies, config=config) - results = fg.execute([fg.nodes['g']], inputs=inputs, overrides=overrides) + # we put a user input node first to ensure that order does not matter with computation order. + results = fg.execute([fg.nodes['b'], fg.nodes['g']], inputs=inputs, overrides=overrides) do_all_args = {key + '_val': val for key, val in {**config, **inputs, **overrides}.items()} expected_results = tests.resources.optional_dependencies._do_all(**do_all_args) assert results['g'] == expected_results['g'] +def test_optional_input_behavior(): + """Tests that if we request optional user inputs that are not provided, we do not break. And if they are + we do the right thing and return them. + """ + fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={}) + # nothing passed, so nothing returned + result = fg.execute([fg.nodes['b'], fg.nodes['a']], inputs={}, overrides={}) + assert result == {} + # something passed, something returned via config + fg2 = graph.FunctionGraph(tests.resources.optional_dependencies, config={'a': 10}) + result = fg2.execute([fg.nodes['b'], fg.nodes['a']], inputs={}, overrides={}) + assert result == {'a': 10} + # something passed, something returned via inputs + result = fg.execute([fg.nodes['b'], fg.nodes['a']], inputs={'a': 10}, overrides={}) + assert result == {'a': 10} + # something passed, something returned via overrides + result = fg.execute([fg.nodes['b'], fg.nodes['a']], inputs={}, overrides={'a': 10}) + assert result == {'a': 10} + + +@pytest.mark.parametrize('node_order', [ + ['f', 'h', 'i'], + ['f', 'i', 'h'], + ['i', 'f', 'h'], + ['i', 'h', 'f'], + ['h', 'f', 'i'], + ['h', 'i', 'f'], +]) +def test_user_input_breaks_if_required_missing(node_order): + """Tests that we break because `h` is required but is not passed in.""" + fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={}) + permutation = [ + fg.nodes[n] for n in node_order + ] + with pytest.raises(NotImplementedError): + fg.execute(permutation, inputs={}, overrides={}) + + +@pytest.mark.parametrize('node_order', [ + ['f', 'h', 'i'], + ['f', 'i', 'h'], + ['i', 'f', 'h'], + ['i', 'h', 'f'], + ['h', 'f', 'i'], + ['h', 'i', 'f'], +]) +def test_user_input_does_not_break_if_required_provided(node_order): + """Tests that things work no matter the order because `h` is required and is passed in, while `f` is optional.""" + fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={'h': 10}) + permutation = [ + fg.nodes[n] for n in node_order + ] + result = fg.execute(permutation, inputs={}, overrides={}) + assert result == {'h': 10, 'i': 17} + + +def test_optional_donot_drop_none(): + """We do not want to drop `None` results from functions. We want to pass them through to the function. + + This is here to enshrine the current behavior. + """ + fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={'h': None}) + # enshrine behavior that None is not removed from being passed to the function. + results = fg.execute([fg.nodes['h'], fg.nodes['i']], inputs={}, overrides={}) + assert results == {'h': None, 'i': 17} + fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={}) + results = fg.execute([fg.nodes['j'], fg.nodes['none_result'], fg.nodes['f']], inputs={}, overrides={}) + assert results == {'j': None, 'none_result': None} # f omitted cause it's optional. + + def test_optional_get_required_compile_time(): """Tests that getting required with optionals at compile time returns everything TODO -- change this to be testing a different function (compile time) than runtime.