Skip to content

Commit

Permalink
Conditions testing both paths (flyteorg#326)
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 authored Jul 8, 2021
1 parent 3653ae6 commit b4742d1
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions cookbook/core/control_flow/run_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,14 @@ def multiplier_3(my_input: float) -> float:
#
# Wondering how output values get these methods. In a workflow no output value is available to access directly. The inputs and outputs are auto-wrapped in a special object called :py:class:`flytekit.extend.Promise`.
#
# In this contrived example for ease of testing, we are creating a biased coin whose seed we can control.
@task
def coin_toss() -> bool:
def coin_toss(seed: int) -> bool:
"""
Mimic some condition checking to see if something ran correctly
"""
if random.random() < 0.5:
r = random.Random(seed)
if r.random() < 0.5:
return True
return False

Expand All @@ -148,8 +150,8 @@ def success() -> int:


@workflow
def basic_boolean_wf() -> int:
result = coin_toss()
def basic_boolean_wf(seed: int = 5) -> int:
result = coin_toss(seed=seed)
return (
conditional("test").if_(result.is_true()).then(success()).else_().then(failed())
)
Expand Down Expand Up @@ -222,11 +224,11 @@ def nested_conditions(my_input: float) -> float:
# to be the subset of outputs that all then-nodes produce. In the following example, we call square() in one condition
# and call double in another.
@task
def sum_diff(a: float, b: float) -> (float, float):
def sum_diff(a: float, b: float) -> float:
"""
sum_diff returns the sum and difference between a and b.
"""
return a + b, a - b
return a + b


# %%
Expand All @@ -241,11 +243,11 @@ def sum_diff(a: float, b: float) -> (float, float):
#
# x = 0 if m < 0 else 1
@workflow
def consume_outputs(my_input: float) -> float:
is_heads = coin_toss()
def consume_outputs(my_input: float, seed: int = 5) -> float:
is_heads = coin_toss(seed=seed)
res = (
conditional("double_or_square")
.if_(is_heads == True)
.if_(is_heads.is_true())
.then(square(n=my_input))
.else_()
.then(sum_diff(a=my_input, b=my_input))
Expand All @@ -259,4 +261,5 @@ def consume_outputs(my_input: float) -> float:
# %%
# As usual local execution does not change
if __name__ == "__main__":
print(f"consume_outputs(0.4) => {consume_outputs(my_input=0.4)}")
print(f"consume_outputs(0.4) with default seed=5. This should return output of sum_diff => {consume_outputs(my_input=0.4)}")
print(f"consume_outputs(0.4, seed=7), this should return output of square => {consume_outputs(my_input=0.4, seed=7)}")

0 comments on commit b4742d1

Please sign in to comment.