From eafd5e2f79de4e6c799f1124287c36a0c8511bab Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Mon, 6 Feb 2023 20:40:00 +0530 Subject: [PATCH] Fix primitive decoder when evaluating Promise (#1432) Signed-off-by: Samhita Alla --- flytekit/core/promise.py | 22 ++++++++++----------- tests/flytekit/unit/core/test_conditions.py | 16 +++++++++++++++ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index bef86cc9ed..935556dafd 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -70,7 +70,6 @@ def extract_value( val_type: type, flyte_literal_type: _type_models.LiteralType, ) -> _literal_models.Literal: - if isinstance(input_val, list): lt = flyte_literal_type python_type = val_type @@ -143,17 +142,16 @@ def extract_value( def get_primitive_val(prim: Primitive) -> Any: - if prim.integer: - return prim.integer - if prim.datetime: - return prim.datetime - if prim.boolean: - return prim.boolean - if prim.duration: - return prim.duration - if prim.string_value: - return prim.string_value - return prim.float_value + for value in [ + prim.integer, + prim.float_value, + prim.string_value, + prim.boolean, + prim.datetime, + prim.duration, + ]: + if value is not None: + return value class ConjunctionOps(Enum): diff --git a/tests/flytekit/unit/core/test_conditions.py b/tests/flytekit/unit/core/test_conditions.py index be85918b74..ca234c743b 100644 --- a/tests/flytekit/unit/core/test_conditions.py +++ b/tests/flytekit/unit/core/test_conditions.py @@ -71,6 +71,22 @@ def multiplier_2(my_input: float) -> float: multiplier_2(my_input=10.0) +def test_condition_else_int(): + @workflow + def multiplier_3(my_input: int) -> float: + return ( + conditional("fractions") + .if_((my_input >= 0) & (my_input < 1.0)) + .then(double(n=my_input)) + .elif_((my_input > 1.0) & (my_input < 10.0)) + .then(square(n=my_input)) + .else_() + .fail("The input must be between 0 and 10") + ) + + assert multiplier_3(my_input=0) == 0 + + def test_condition_sub_workflows(): @task def sum_div_sub(a: int, b: int) -> typing.NamedTuple("Outputs", sum=int, div=int, sub=int):