Skip to content

Commit

Permalink
feat(sdk): support dsl.If, dsl.Elif, and dsl.Else (#9894)
Browse files Browse the repository at this point in the history
* support if/elif/else

* deprecate dsl.Condition

* alter rebase

* update release notes

* address review feedback

* change BinaryOperation to ConditionOperation
  • Loading branch information
connor-mccarthy authored Sep 11, 2023
1 parent 1791818 commit c6b236d
Show file tree
Hide file tree
Showing 16 changed files with 2,183 additions and 65 deletions.
1 change: 1 addition & 0 deletions sdk/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


## Features
* Add support for `dsl.If`, `dsl.Elif`, and `dsl.Else` control flow context managers; deprecate `dsl.Condition` in favor of `dsl.If` [\#9894](https://github.com/kubeflow/pipelines/pull/9894)

## Breaking changes

Expand Down
326 changes: 326 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from kfp.dsl import OutputPath
from kfp.dsl import pipeline_task
from kfp.dsl import PipelineTaskFinalStatus
from kfp.dsl import tasks_group
from kfp.dsl import yaml_component
from kfp.dsl.types import type_utils
from kfp.pipeline_spec import pipeline_spec_pb2
Expand Down Expand Up @@ -4161,5 +4162,330 @@ def my_pipeline(
'Component output artifact.')


@dsl.component
def flip_coin() -> str:
import random
return 'heads' if random.randint(0, 1) == 0 else 'tails'


@dsl.component
def print_and_return(text: str) -> str:
print(text)
return text


@dsl.component
def flip_three_sided_coin() -> str:
import random
val = random.randint(0, 2)

if val == 0:
return 'heads'
elif val == 1:
return 'tails'
else:
return 'draw'


@dsl.component
def int_zero_through_three() -> int:
import random
return random.randint(0, 3)


class TestConditionLogic(unittest.TestCase):

def test_if(self):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')

self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-1']
.trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads'"
)

def test_if_else(self):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')
with dsl.Else():
print_and_return(text='Got tails!')

self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-1']
.trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads'"
)

self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-2']
.trigger_policy.condition,
"!(inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads')"
)

def test_if_elif_else(self):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_three_sided_coin()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')
with dsl.Elif(flip_coin_task.output == 'tails'):
print_and_return(text='Got tails!')
with dsl.Else():
print_and_return(text='Draw!')

self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-1']
.trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--flip-three-sided-coin-Output'] == 'heads'"
)
self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-2']
.trigger_policy.condition,
"!(inputs.parameter_values['pipelinechannel--flip-three-sided-coin-Output'] == 'heads') && inputs.parameter_values['pipelinechannel--flip-three-sided-coin-Output'] == 'tails'"
)

self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-3']
.trigger_policy.condition,
"!(inputs.parameter_values['pipelinechannel--flip-three-sided-coin-Output'] == 'heads') && !(inputs.parameter_values['pipelinechannel--flip-three-sided-coin-Output'] == 'tails')"
)

def test_if_multiple_elif_else(self):

@dsl.pipeline
def int_to_string():
int_task = int_zero_through_three()
with dsl.If(int_task.output == 0):
print_and_return(text='Got zero!')
with dsl.Elif(int_task.output == 1):
print_and_return(text='Got one!')
with dsl.Elif(int_task.output == 2):
print_and_return(text='Got two!')
with dsl.Else():
print_and_return(text='Got three!')

self.assertEqual(
int_to_string.pipeline_spec.root.dag.tasks['condition-1']
.trigger_policy.condition,
"int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 0"
)
self.assertEqual(
int_to_string.pipeline_spec.root.dag.tasks['condition-2']
.trigger_policy.condition,
"!(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 0) && int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 1"
)
self.assertEqual(
int_to_string.pipeline_spec.root.dag.tasks['condition-3']
.trigger_policy.condition,
"!(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 0) && !(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 1) && int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 2"
)
self.assertEqual(
int_to_string.pipeline_spec.root.dag.tasks['condition-4']
.trigger_policy.condition,
"!(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 0) && !(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 1) && !(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 2)"
)

def test_nested_if_elif_else_with_pipeline_param(self):

@dsl.pipeline
def flip_coin_pipeline(confirm: bool):
int_task = int_zero_through_three()
heads_task = flip_coin()

with dsl.If(heads_task.output == 'heads'):
with dsl.If(int_task.output == 0):
print_and_return(text='Got zero!')

with dsl.Elif(int_task.output == 1):
task = print_and_return(text='Got one!')
with dsl.If(confirm == True):
print_and_return(text='Confirmed: definitely got one.')

with dsl.Elif(int_task.output == 2):
print_and_return(text='Got two!')

with dsl.Else():
print_and_return(text='Got three!')

# top level conditions
self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-1']
.trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads'"
)
# second level nested conditions
self.assertEqual(
flip_coin_pipeline.pipeline_spec.components['comp-condition-1'].dag
.tasks['condition-2'].trigger_policy.condition,
"int(inputs.parameter_values[\'pipelinechannel--int-zero-through-three-Output\']) == 0"
)
self.assertEqual(
flip_coin_pipeline.pipeline_spec.components['comp-condition-1'].dag
.tasks['condition-3'].trigger_policy.condition,
"!(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 0) && int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 1"
)
self.assertEqual(
flip_coin_pipeline.pipeline_spec.components['comp-condition-1'].dag
.tasks['condition-5'].trigger_policy.condition,
"!(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 0) && !(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 1) && int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 2"
)
self.assertEqual(
flip_coin_pipeline.pipeline_spec.components['comp-condition-1'].dag
.tasks['condition-6'].trigger_policy.condition,
"!(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 0) && !(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 1) && !(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 2)"
)
# third level nested conditions
self.assertEqual(
flip_coin_pipeline.pipeline_spec.components['comp-condition-3'].dag
.tasks['condition-4'].trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--confirm'] == true")

def test_multiple_ifs_permitted(self):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')
with dsl.If(flip_coin_task.output == 'tails'):
print_and_return(text='Got tails!')

self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-1']
.trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads'"
)
self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-2']
.trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'tails'"
)

def test_multiple_else_not_permitted(self):
with self.assertRaisesRegex(
tasks_group.InvalidControlFlowException,
r'Cannot use dsl\.Else following another dsl\.Else\. dsl\.Else can only be used following an upstream dsl\.If or dsl\.Elif\.'
):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')
with dsl.Else():
print_and_return(text='Got tails!')
with dsl.Else():
print_and_return(text='Got tails!')

def test_else_no_if_not_supported(self):
with self.assertRaisesRegex(
tasks_group.InvalidControlFlowException,
r'dsl\.Else can only be used following an upstream dsl\.If or dsl\.Elif\.'
):

@dsl.pipeline
def flip_coin_pipeline():
with dsl.Else():
print_and_return(text='Got unknown')

def test_elif_no_if_not_supported(self):
with self.assertRaisesRegex(
tasks_group.InvalidControlFlowException,
r'dsl\.Elif can only be used following an upstream dsl\.If or dsl\.Elif\.'
):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
with dsl.Elif(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')

def test_boolean_condition_has_helpful_error(self):
with self.assertRaisesRegex(
ValueError,
r'Got constant boolean True as a condition\. This is likely because the provided condition evaluated immediately\. At least one of the operands must be an output from an upstream task or a pipeline parameter\.'
):

@dsl.pipeline
def my_pipeline():
with dsl.Condition('foo' == 'foo'):
print_and_return(text='I will always run.')

def test_boolean_elif_has_helpful_error(self):
with self.assertRaisesRegex(
ValueError,
r'Got constant boolean False as a condition\. This is likely because the provided condition evaluated immediately\. At least one of the operands must be an output from an upstream task or a pipeline parameter\.'
):

@dsl.pipeline
def my_pipeline(text: str):
with dsl.If(text == 'foo'):
print_and_return(text='I will always run.')
with dsl.Elif('foo' == 'bar'):
print_and_return(text='I will never run.')

def test_tasks_instantiated_between_if_else_and_elif_permitted(self):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads on coin one!')

flip_coin_task_2 = flip_coin()

with dsl.Elif(flip_coin_task_2.output == 'tails'):
print_and_return(text='Got heads on coin two!')

flip_coin_task_3 = flip_coin()

with dsl.Else():
print_and_return(
text=f'Coin three result: {flip_coin_task_3.output}')

self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-1']
.trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads'"
)
self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-2']
.trigger_policy.condition,
"!(inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads') && inputs.parameter_values['pipelinechannel--flip-coin-2-Output'] == 'tails'"
)
self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-3']
.trigger_policy.condition,
"!(inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads') && !(inputs.parameter_values['pipelinechannel--flip-coin-2-Output'] == 'tails')"
)

def test_other_control_flow_instantiated_between_if_else_not_permitted(
self):
with self.assertRaisesRegex(
tasks_group.InvalidControlFlowException,
'dsl\.Else can only be used following an upstream dsl\.If or dsl\.Elif\.'
):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')
with dsl.ParallelFor(['foo', 'bar']) as item:
print_and_return(text=item)
with dsl.Else():
print_and_return(text='Got tails!')


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit c6b236d

Please sign in to comment.