Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Commit

Permalink
Fixes case where optional user inputs broke computation
Browse files Browse the repository at this point in the history
The execute function gets all upstream nodes of the required node to compute.
This will mean that there will likely be "user input" nodes to cycle through. When
we were computing the DFS value for them, we would assume they were required.

To illustrate, if you had a function that had optional input, `baz` e.g.

```python

def foo(bar: int, baz: float = 1.0) -> float:
```
This meant that if you did not pass in a value for `baz`, and `baz`
was a user input, Hamilton would complain that a required node
was not provided. Even though it was not required for computation.

So to fix that, in execute any user node is now marked with `optional`.
I believe this is fine to do, because if this is not the case, there will be a
node in the graph that will have `baz` as a REQUIRED dependency, and
thus things will break appropriately.

To help with that, I also fixed and added some unit tests.

One unit test is to ensure that we don't remove passing in `None` values
as part of the kwargs to the function. Since that's what we do now, and this
was another way to fix this bug, which I think would be the wrong way to go about it.

Otherwise I added tests to ensure that node order does not change the result too.
  • Loading branch information
skrawcz committed Jun 21, 2022
1 parent b5e65de commit 21c007e
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 2 deletions.
6 changes: 5 additions & 1 deletion hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,11 @@ def dfs_traverse(node: node.Node, dependency_type: DependencyType = DependencyTy
computed[node.name] = value

for final_var_node in nodes:
dfs_traverse(final_var_node)
dep_type = DependencyType.REQUIRED
if final_var_node.user_defined:
# from the top level, we don't know if this UserInput is required. So mark as optional.
dep_type = DependencyType.OPTIONAL
dfs_traverse(final_var_node, dep_type)
return computed

@staticmethod
Expand Down
17 changes: 17 additions & 0 deletions tests/resources/optional_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ def g(e: int, f: int = _F) -> int:
return e + f


def i(h: int, f: int = _F) -> int:
"""we will pass None to e and so don't want the code to break for the unit test."""
if h is None:
h = 10
return h + f


def none_result() -> int:
"""Function to show that we don't filter out the result."""
return None


def j(none_result: int, f: int = _F) -> int:
# dont use f.
return none_result


def _do_all(a_val: int = _A, b_val: int = _B, d_val: int = _D, f_val: int = _F) -> Dict[str, Any]:
c_val = c(a_val, b_val)
e_val = e(c_val, d_val)
Expand Down
60 changes: 59 additions & 1 deletion tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
from itertools import permutations
import tempfile
import typing
import uuid
Expand Down Expand Up @@ -560,12 +561,69 @@ def test_optional_execute(config, inputs, overrides):
Be careful adding tests with conflicting values between them.
"""
fg = graph.FunctionGraph(tests.resources.optional_dependencies, config=config)
results = fg.execute([fg.nodes['g']], inputs=inputs, overrides=overrides)
# we put a user input node first to ensure that order does not matter with computation order.
results = fg.execute([fg.nodes['b'], fg.nodes['g']], inputs=inputs, overrides=overrides)
do_all_args = {key + '_val': val for key, val in {**config, **inputs, **overrides}.items()}
expected_results = tests.resources.optional_dependencies._do_all(**do_all_args)
assert results['g'] == expected_results['g']


def test_optional_input_behavior():
"""Tests that if we request optional user inputs that are not provided, we do not break. And if they are
we do the right thing and return them.
"""
fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={})
# nothing passed, so nothing returned
result = fg.execute([fg.nodes['b'], fg.nodes['a']], inputs={}, overrides={})
assert result == {}
# something passed, something returned via config
fg2 = graph.FunctionGraph(tests.resources.optional_dependencies, config={'a': 10})
result = fg2.execute([fg.nodes['b'], fg.nodes['a']], inputs={}, overrides={})
assert result == {'a': 10}
# something passed, something returned via inputs
result = fg.execute([fg.nodes['b'], fg.nodes['a']], inputs={'a': 10}, overrides={})
assert result == {'a': 10}
# something passed, something returned via overrides
result = fg.execute([fg.nodes['b'], fg.nodes['a']], inputs={}, overrides={'a': 10})
assert result == {'a': 10}


@pytest.mark.parametrize('node_order', list(permutations('fhi')))
def test_user_input_breaks_if_required_missing(node_order):
"""Tests that we break because `h` is required but is not passed in."""
fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={})
permutation = [
fg.nodes[n] for n in node_order
]
with pytest.raises(NotImplementedError):
fg.execute(permutation, inputs={}, overrides={})


@pytest.mark.parametrize('node_order', list(permutations('fhi')))
def test_user_input_does_not_break_if_required_provided(node_order):
"""Tests that things work no matter the order because `h` is required and is passed in, while `f` is optional."""
fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={'h': 10})
permutation = [
fg.nodes[n] for n in node_order
]
result = fg.execute(permutation, inputs={}, overrides={})
assert result == {'h': 10, 'i': 17}


def test_optional_donot_drop_none():
"""We do not want to drop `None` results from functions. We want to pass them through to the function.
This is here to enshrine the current behavior.
"""
fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={'h': None})
# enshrine behavior that None is not removed from being passed to the function.
results = fg.execute([fg.nodes['h'], fg.nodes['i']], inputs={}, overrides={})
assert results == {'h': None, 'i': 17}
fg = graph.FunctionGraph(tests.resources.optional_dependencies, config={})
results = fg.execute([fg.nodes['j'], fg.nodes['none_result'], fg.nodes['f']], inputs={}, overrides={})
assert results == {'j': None, 'none_result': None} # f omitted cause it's optional.


def test_optional_get_required_compile_time():
"""Tests that getting required with optionals at compile time returns everything
TODO -- change this to be testing a different function (compile time) than runtime.
Expand Down

0 comments on commit 21c007e

Please sign in to comment.