diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 2a3687c06c..12cd86d4af 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -619,7 +619,7 @@ def binding_data_from_python_std( f"Cannot pass output from task {t_value.task_name} that produces no outputs to a downstream task" ) - elif expected_literal_type.union_type is not None: + elif t_value is not None and expected_literal_type.union_type is not None: for i in range(len(expected_literal_type.union_type.variants)): try: lt_type = expected_literal_type.union_type.variants[i] diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 9478cc33ba..b214764c50 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -145,3 +145,18 @@ def t1(a: typing.Union[float, typing.Dict[str, int]]): t1.interface.inputs, t1.python_interface.inputs, ) + + +def test_optional_task_kwargs(): + from typing import Optional + + from flytekit import Workflow + + @task + def func(foo: Optional[int] = None): + pass + + wf = Workflow(name="test") + wf.add_entity(func, foo=None) + + wf()