diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 42144cceba..6724f243e3 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -512,18 +512,23 @@ def my_wf(a: int, b: str) -> (int, str): def test_wf1_branches_ne(): + @task + def t1(a: int) -> int: + return a + 1 + @task def t2(a: str) -> str: return a @workflow def my_wf(a: int, b: str) -> str: - return conditional("test1").if_(a != 5).then(t2(a=b)).else_().fail("Unable to choose branch") + new_a = t1(a=a) + return conditional("test1").if_(new_a != 5).then(t2(a=b)).else_().fail("Unable to choose branch") with pytest.raises(ValueError): - my_wf(a=5, b="hello") + my_wf(a=4, b="hello") - x = my_wf(a=6, b="hello") + x = my_wf(a=5, b="hello") assert x == "hello"