Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sdk): support dsl.If, dsl.Elif, and dsl.Else #9894

Merged
merged 6 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


## Features
* Add support for `dsl.If`, `dsl.Elif`, and `dsl.Else` control flow context managers; deprecate `dsl.Condition` in favor of `dsl.If` [\#9894](https://github.com/kubeflow/pipelines/pull/9894)

## Breaking changes

Expand Down
326 changes: 326 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from kfp.dsl import OutputPath
from kfp.dsl import pipeline_task
from kfp.dsl import PipelineTaskFinalStatus
from kfp.dsl import tasks_group
from kfp.dsl import yaml_component
from kfp.dsl.types import type_utils
from kfp.pipeline_spec import pipeline_spec_pb2
Expand Down Expand Up @@ -4161,5 +4162,330 @@ def my_pipeline(
'Component output artifact.')


@dsl.component
def flip_coin() -> str:
import random
return 'heads' if random.randint(0, 1) == 0 else 'tails'


@dsl.component
def print_and_return(text: str) -> str:
print(text)
return text


@dsl.component
def flip_three_sided_coin() -> str:
import random
val = random.randint(0, 2)

if val == 0:
return 'heads'
elif val == 1:
return 'tails'
else:
return 'draw'


@dsl.component
def int_zero_through_three() -> int:
import random
return random.randint(0, 3)


class TestConditionLogic(unittest.TestCase):

def test_if(self):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')

self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-1']
.trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads'"
)

def test_if_else(self):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')
with dsl.Else():
print_and_return(text='Got tails!')

self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-1']
.trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads'"
)

self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-2']
.trigger_policy.condition,
"!(inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads')"
)

def test_if_elif_else(self):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_three_sided_coin()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')
with dsl.Elif(flip_coin_task.output == 'tails'):
print_and_return(text='Got tails!')
with dsl.Else():
print_and_return(text='Draw!')

self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-1']
.trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--flip-three-sided-coin-Output'] == 'heads'"
)
self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-2']
.trigger_policy.condition,
"!(inputs.parameter_values['pipelinechannel--flip-three-sided-coin-Output'] == 'heads') && inputs.parameter_values['pipelinechannel--flip-three-sided-coin-Output'] == 'tails'"
)

self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-3']
.trigger_policy.condition,
"!(inputs.parameter_values['pipelinechannel--flip-three-sided-coin-Output'] == 'heads') && !(inputs.parameter_values['pipelinechannel--flip-three-sided-coin-Output'] == 'tails')"
)

def test_if_multiple_elif_else(self):

@dsl.pipeline
def int_to_string():
int_task = int_zero_through_three()
with dsl.If(int_task.output == 0):
print_and_return(text='Got zero!')
with dsl.Elif(int_task.output == 1):
print_and_return(text='Got one!')
with dsl.Elif(int_task.output == 2):
print_and_return(text='Got two!')
with dsl.Else():
print_and_return(text='Got three!')

self.assertEqual(
int_to_string.pipeline_spec.root.dag.tasks['condition-1']
.trigger_policy.condition,
"int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 0"
)
self.assertEqual(
int_to_string.pipeline_spec.root.dag.tasks['condition-2']
.trigger_policy.condition,
"!(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 0) && int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 1"
)
self.assertEqual(
int_to_string.pipeline_spec.root.dag.tasks['condition-3']
.trigger_policy.condition,
"!(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 0) && !(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 1) && int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 2"
)
self.assertEqual(
int_to_string.pipeline_spec.root.dag.tasks['condition-4']
.trigger_policy.condition,
"!(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 0) && !(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 1) && !(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 2)"
)

def test_nested_if_elif_else_with_pipeline_param(self):

@dsl.pipeline
def flip_coin_pipeline(confirm: bool):
int_task = int_zero_through_three()
heads_task = flip_coin()

with dsl.If(heads_task.output == 'heads'):
with dsl.If(int_task.output == 0):
print_and_return(text='Got zero!')

with dsl.Elif(int_task.output == 1):
task = print_and_return(text='Got one!')
with dsl.If(confirm == True):
print_and_return(text='Confirmed: definitely got one.')

with dsl.Elif(int_task.output == 2):
print_and_return(text='Got two!')

with dsl.Else():
print_and_return(text='Got three!')

# top level conditions
self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-1']
.trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads'"
)
# second level nested conditions
self.assertEqual(
flip_coin_pipeline.pipeline_spec.components['comp-condition-1'].dag
.tasks['condition-2'].trigger_policy.condition,
"int(inputs.parameter_values[\'pipelinechannel--int-zero-through-three-Output\']) == 0"
)
self.assertEqual(
flip_coin_pipeline.pipeline_spec.components['comp-condition-1'].dag
.tasks['condition-3'].trigger_policy.condition,
"!(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 0) && int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 1"
)
self.assertEqual(
flip_coin_pipeline.pipeline_spec.components['comp-condition-1'].dag
.tasks['condition-5'].trigger_policy.condition,
"!(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 0) && !(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 1) && int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 2"
)
self.assertEqual(
flip_coin_pipeline.pipeline_spec.components['comp-condition-1'].dag
.tasks['condition-6'].trigger_policy.condition,
"!(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 0) && !(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 1) && !(int(inputs.parameter_values['pipelinechannel--int-zero-through-three-Output']) == 2)"
)
# third level nested conditions
self.assertEqual(
flip_coin_pipeline.pipeline_spec.components['comp-condition-3'].dag
.tasks['condition-4'].trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--confirm'] == true")

def test_multiple_ifs_permitted(self):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')
with dsl.If(flip_coin_task.output == 'tails'):
print_and_return(text='Got tails!')

self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-1']
.trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads'"
)
self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-2']
.trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'tails'"
)

def test_multiple_else_not_permitted(self):
with self.assertRaisesRegex(
tasks_group.InvalidControlFlowException,
r'Cannot use dsl\.Else following another dsl\.Else\. dsl\.Else can only be used following an upstream dsl\.If or dsl\.Elif\.'
):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')
with dsl.Else():
print_and_return(text='Got tails!')
with dsl.Else():
print_and_return(text='Got tails!')

def test_else_no_if_not_supported(self):
with self.assertRaisesRegex(
tasks_group.InvalidControlFlowException,
r'dsl\.Else can only be used following an upstream dsl\.If or dsl\.Elif\.'
):

@dsl.pipeline
def flip_coin_pipeline():
with dsl.Else():
print_and_return(text='Got unknown')

def test_elif_no_if_not_supported(self):
with self.assertRaisesRegex(
tasks_group.InvalidControlFlowException,
r'dsl\.Elif can only be used following an upstream dsl\.If or dsl\.Elif\.'
):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
with dsl.Elif(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')

def test_boolean_condition_has_helpful_error(self):
with self.assertRaisesRegex(
ValueError,
r'Got constant boolean True as a condition\. This is likely because the provided condition evaluated immediately\. At least one of the operands must be an output from an upstream task or a pipeline parameter\.'
):

@dsl.pipeline
def my_pipeline():
with dsl.Condition('foo' == 'foo'):
print_and_return(text='I will always run.')

def test_boolean_elif_has_helpful_error(self):
with self.assertRaisesRegex(
ValueError,
r'Got constant boolean False as a condition\. This is likely because the provided condition evaluated immediately\. At least one of the operands must be an output from an upstream task or a pipeline parameter\.'
):

@dsl.pipeline
def my_pipeline(text: str):
with dsl.If(text == 'foo'):
print_and_return(text='I will always run.')
with dsl.Elif('foo' == 'bar'):
print_and_return(text='I will never run.')

def test_tasks_instantiated_between_if_else_and_elif_permitted(self):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads on coin one!')

flip_coin_task_2 = flip_coin()

with dsl.Elif(flip_coin_task_2.output == 'tails'):
print_and_return(text='Got heads on coin two!')

flip_coin_task_3 = flip_coin()

with dsl.Else():
print_and_return(
text=f'Coin three result: {flip_coin_task_3.output}')

self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-1']
.trigger_policy.condition,
"inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads'"
)
self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-2']
.trigger_policy.condition,
"!(inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads') && inputs.parameter_values['pipelinechannel--flip-coin-2-Output'] == 'tails'"
)
self.assertEqual(
flip_coin_pipeline.pipeline_spec.root.dag.tasks['condition-3']
.trigger_policy.condition,
"!(inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads') && !(inputs.parameter_values['pipelinechannel--flip-coin-2-Output'] == 'tails')"
)

def test_other_control_flow_instantiated_between_if_else_not_permitted(
self):
with self.assertRaisesRegex(
tasks_group.InvalidControlFlowException,
'dsl\.Else can only be used following an upstream dsl\.If or dsl\.Elif\.'
):

@dsl.pipeline
def flip_coin_pipeline():
flip_coin_task = flip_coin()
with dsl.If(flip_coin_task.output == 'heads'):
print_and_return(text='Got heads!')
with dsl.ParallelFor(['foo', 'bar']) as item:
print_and_return(text=item)
with dsl.Else():
print_and_return(text='Got tails!')


if __name__ == '__main__':
unittest.main()
Loading