From dddfbacf92805a64346e7ce77dd2864a240d2f42 Mon Sep 17 00:00:00 2001
From: connor-mccarthy
Date: Fri, 6 Oct 2023 15:42:02 -0700
Subject: [PATCH 1/3] support dsl.OneOf

---
 sdk/RELEASE.md                                |   1 +
 sdk/python/kfp/compiler/compiler_test.py      | 877 ++++++++++++++++++
 sdk/python/kfp/compiler/compiler_utils.py     | 247 +++--
 .../kfp/compiler/pipeline_spec_builder.py     | 122 ++-
 sdk/python/kfp/dsl/__init__.py                |   3 +
 sdk/python/kfp/dsl/for_loop.py                |  11 +
 sdk/python/kfp/dsl/pipeline_channel.py        | 259 +++++-
 sdk/python/kfp/dsl/pipeline_channel_test.py   | 224 ++++-
 sdk/python/kfp/dsl/pipeline_context.py        |   4 +
 sdk/python/kfp/dsl/tasks_group.py             |  13 +
 .../pipelines/if_elif_else_complex.py         |  14 +-
 .../pipelines/if_elif_else_complex.yaml       | 198 +++-
 ... => if_elif_else_with_oneof_parameters.py} |  24 +-
 .../if_elif_else_with_oneof_parameters.yaml   | 420 +++++++++
 sdk/python/test_data/pipelines/if_else.yaml   | 214 -----
 .../pipelines/if_else_with_oneof_artifacts.py |  60 ++
 .../if_else_with_oneof_artifacts.yaml         | 380 ++++++++
 ...se.py => if_else_with_oneof_parameters.py} |  11 +-
 ...aml => if_else_with_oneof_parameters.yaml} | 139 +--
 sdk/python/test_data/test_data_config.yaml    |  13 +-
 20 files changed, 2826 insertions(+), 408 deletions(-)
 rename sdk/python/test_data/pipelines/{if_elif_else.py => if_elif_else_with_oneof_parameters.py} (66%)
 create mode 100644 sdk/python/test_data/pipelines/if_elif_else_with_oneof_parameters.yaml
 delete mode 100644 sdk/python/test_data/pipelines/if_else.yaml
 create mode 100644 sdk/python/test_data/pipelines/if_else_with_oneof_artifacts.py
 create mode 100644 sdk/python/test_data/pipelines/if_else_with_oneof_artifacts.yaml
 rename sdk/python/test_data/pipelines/{if_else.py => if_else_with_oneof_parameters.py} (79%)
 rename sdk/python/test_data/pipelines/{if_elif_else.yaml => if_else_with_oneof_parameters.yaml} (72%)

diff --git a/sdk/RELEASE.md b/sdk/RELEASE.md
index 09b20e1b545..502f530072a 100644
--- a/sdk/RELEASE.md
+++ b/sdk/RELEASE.md
@@ -3,6 +3,7 @@
 
 ## Features
+* Support collecting outputs from conditional branches using `dsl.OneOf` [\#10067](https://github.com/kubeflow/pipelines/pull/10067)
 
 ## Breaking changes
 
 ## Deprecations
diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py
index 6cf07614616..1eb53de15dc 100644
--- a/sdk/python/kfp/compiler/compiler_test.py
+++ b/sdk/python/kfp/compiler/compiler_test.py
@@ -147,11 +147,36 @@ def print_hello():
     print('hello')
 
 
+@dsl.component
+def cleanup():
+    print('cleanup')
+
+
 @dsl.component
 def double(num: int) -> int:
     return 2 * num
 
 
+@dsl.component
+def print_and_return_as_artifact(text: str, a: Output[Artifact]):
+    print(text)
+    with open(a.path, 'w') as f:
+        f.write(text)
+
+
+@dsl.component
+def print_and_return_with_output_key(text: str, output_key: OutputPath(str)):
+    print(text)
+    with open(output_key, 'w') as f:
+        f.write(text)
+
+
+@dsl.component
+def print_artifact(a: Input[Artifact]):
+    with open(a.path) as f:
+        print(f.read())
+
+
 ###########
 
@@ -4140,6 +4165,44 @@ def my_pipeline(
                          'Component output artifact.')
 
 
+class TestCannotReturnFromWithinControlFlowGroup(unittest.TestCase):
+
+    def test_condition_raises(self):
+        with self.assertRaisesRegex(
+            compiler_utils.InvalidTopologyException,
+            r'Pipeline outputs may only be returned from the top level of the pipeline function scope\. Got pipeline output from within the control flow group dsl\.Condition\.'
+        ):
+
+            @dsl.pipeline
+            def my_pipeline(string: str = 'string') -> str:
+                with dsl.Condition(string == 'foo'):
+                    return print_and_return(text=string).output
+
+    def test_loop_raises(self):
+
+        with self.assertRaisesRegex(
+            compiler_utils.InvalidTopologyException,
+            r'Pipeline outputs may only be returned from the top level of the pipeline function scope\. Got pipeline output from within the control flow group dsl\.ParallelFor\.'
+        ):
+
+            @dsl.pipeline
+            def my_pipeline(string: str = 'string') -> str:
+                with dsl.ParallelFor([1, 2, 3]):
+                    return print_and_return(text=string).output
+
+    def test_exit_handler_raises(self):
+
+        with self.assertRaisesRegex(
+            compiler_utils.InvalidTopologyException,
+            r'Pipeline outputs may only be returned from the top level of the pipeline function scope\. Got pipeline output from within the control flow group dsl\.ExitHandler\.'
+        ):
+
+            @dsl.pipeline
+            def my_pipeline(string: str = 'string') -> str:
+                with dsl.ExitHandler(print_and_return(text='exit task')):
+                    return print_and_return(text=string).output
+
+
 class TestConditionLogic(unittest.TestCase):
 
     def test_if(self):
@@ -4480,5 +4543,819 @@ def flip_coin_pipeline():
             print_and_return(text='Got tails!')
 
 
+class TestDslOneOf(unittest.TestCase):
+    # The space of possible tests is very large, so we test a representative set of cases covering the following styles of usage:
+    # - upstream conditions: if/else v if/elif/else
+    # - data consumed: parameters v artifacts
+    # - where dsl.OneOf goes: consumed by task v returned v both
+    # - when outputs have different keys: e.g., .output v .outputs[]
+    # - how the if/elif/else are nested and at what level they are consumed
+
+    # Data type validation (e.g., dsl.OneOf(artifact, param) fails) and similar is covered in pipeline_channel_test.py.
+
+    # To help narrow the tests further (we already test lots of aspects in the following cases), we choose to focus on dsl.OneOf behavior, not the conditional logic of If/Elif/Else. This is more verbose, but more maintainable and the behavior under test is clearer.
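+    # As a rough sketch of the shared shape of the cases below (using the
+    # components defined at the top of this module), each test builds a
+    # pipeline along the lines of:
+    #
+    #     with dsl.If(flip_coin_task.output == 'heads'):
+    #         t1 = print_and_return(text='Got heads!')
+    #     with dsl.Else():
+    #         t2 = print_and_return(text='Got tails!')
+    #     oneof = dsl.OneOf(t1.output, t2.output)
+    #
+    # and then asserts on how the oneof is wired through the compiled
+    # PipelineSpec: consumed by a downstream task, returned as a pipeline
+    # output, or both.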
+ + def test_if_else_returned(self): + # if/else + # returned + # parameters + # different output keys + + @dsl.pipeline + def roll_die_pipeline() -> str: + flip_coin_task = flip_coin() + with dsl.If(flip_coin_task.output == 'heads'): + t1 = print_and_return(text='Got heads!') + with dsl.Else(): + t2 = print_and_return_with_output_key(text='Got tails!') + return dsl.OneOf(t1.output, t2.outputs['output_key']) + + # hole punched through if + self.assertEqual( + roll_die_pipeline.pipeline_spec.components['comp-condition-2'] + .output_definitions.parameters[ + 'pipelinechannel--print-and-return-Output'].parameter_type, + type_utils.STRING, + ) + # hole punched through else + self.assertEqual( + roll_die_pipeline.pipeline_spec.components['comp-condition-3'] + .output_definitions.parameters[ + 'pipelinechannel--print-and-return-with-output-key-output_key'] + .parameter_type, + type_utils.STRING, + ) + # condition-branches surfaces + self.assertEqual( + roll_die_pipeline.pipeline_spec + .components['comp-condition-branches-1'].output_definitions + .parameters['pipelinechannel--condition-branches-1-oneof-1'] + .parameter_type, + type_utils.STRING, + ) + parameter_selectors = roll_die_pipeline.pipeline_spec.components[ + 'comp-condition-branches-1'].dag.outputs.parameters[ + 'pipelinechannel--condition-branches-1-oneof-1'].value_from_oneof.parameter_selectors + self.assertEqual( + parameter_selectors[0].output_parameter_key, + 'pipelinechannel--print-and-return-Output', + ) + self.assertEqual( + parameter_selectors[0].producer_subtask, + 'condition-2', + ) + self.assertEqual( + parameter_selectors[1].output_parameter_key, + 'pipelinechannel--print-and-return-with-output-key-output_key', + ) + self.assertEqual( + parameter_selectors[1].producer_subtask, + 'condition-3', + ) + # surfaced as output + self.assertEqual( + roll_die_pipeline.pipeline_spec.root.dag.outputs + .parameters['Output'].value_from_parameter.producer_subtask, + 'condition-branches-1', + ) + self.assertEqual( + roll_die_pipeline.pipeline_spec.root.dag.outputs + .parameters['Output'].value_from_parameter.output_parameter_key, + 'pipelinechannel--condition-branches-1-oneof-1', + ) + + def test_if_elif_else_returned(self): + # if/elif/else + # returned + # parameters + # different output keys + + @dsl.pipeline + def roll_die_pipeline() -> str: + flip_coin_task = roll_three_sided_die() + with dsl.If(flip_coin_task.output == 'heads'): + t1 = print_and_return(text='Got heads!') + with dsl.Elif(flip_coin_task.output == 'tails'): + t2 = print_and_return(text='Got tails!') + with dsl.Else(): + t3 = print_and_return_with_output_key(text='Draw!') + return dsl.OneOf(t1.output, t2.output, t3.outputs['output_key']) + + # hole punched through if + self.assertEqual( + roll_die_pipeline.pipeline_spec.components['comp-condition-2'] + .output_definitions.parameters[ + 'pipelinechannel--print-and-return-Output'].parameter_type, + type_utils.STRING, + ) + # hole punched through elif + self.assertEqual( + roll_die_pipeline.pipeline_spec.components['comp-condition-3'] + .output_definitions.parameters[ + 'pipelinechannel--print-and-return-2-Output'].parameter_type, + type_utils.STRING, + ) + # hole punched through else + self.assertEqual( + roll_die_pipeline.pipeline_spec.components['comp-condition-4'] + .output_definitions.parameters[ + 'pipelinechannel--print-and-return-with-output-key-output_key'] + .parameter_type, + type_utils.STRING, + ) + # condition-branches surfaces + self.assertEqual( + roll_die_pipeline.pipeline_spec + 
.components['comp-condition-branches-1'].output_definitions
+            .parameters['pipelinechannel--condition-branches-1-oneof-1']
+            .parameter_type,
+            type_utils.STRING,
+        )
+        parameter_selectors = roll_die_pipeline.pipeline_spec.components[
+            'comp-condition-branches-1'].dag.outputs.parameters[
+                'pipelinechannel--condition-branches-1-oneof-1'].value_from_oneof.parameter_selectors
+        self.assertEqual(
+            parameter_selectors[0].output_parameter_key,
+            'pipelinechannel--print-and-return-Output',
+        )
+        self.assertEqual(
+            parameter_selectors[0].producer_subtask,
+            'condition-2',
+        )
+        self.assertEqual(
+            parameter_selectors[1].output_parameter_key,
+            'pipelinechannel--print-and-return-2-Output',
+        )
+        self.assertEqual(
+            parameter_selectors[1].producer_subtask,
+            'condition-3',
+        )
+        self.assertEqual(
+            parameter_selectors[2].output_parameter_key,
+            'pipelinechannel--print-and-return-with-output-key-output_key',
+        )
+        self.assertEqual(
+            parameter_selectors[2].producer_subtask,
+            'condition-4',
+        )
+        # surfaced as output
+        self.assertEqual(
+            roll_die_pipeline.pipeline_spec.root.dag.outputs
+            .parameters['Output'].value_from_parameter.producer_subtask,
+            'condition-branches-1',
+        )
+        self.assertEqual(
+            roll_die_pipeline.pipeline_spec.root.dag.outputs
+            .parameters['Output'].value_from_parameter.output_parameter_key,
+            'pipelinechannel--condition-branches-1-oneof-1',
+        )
+
+    def test_if_elif_else_consumed(self):
+        # tests if/elif/else
+        # consumed
+        # parameters
+        # different output keys
+
+        @dsl.pipeline
+        def roll_die_pipeline():
+            flip_coin_task = roll_three_sided_die()
+            with dsl.If(flip_coin_task.output == 'heads'):
+                t1 = print_and_return(text='Got heads!')
+            with dsl.Elif(flip_coin_task.output == 'tails'):
+                t2 = print_and_return(text='Got tails!')
+            with dsl.Else():
+                t3 = print_and_return_with_output_key(text='Draw!')
+            print_and_return(
+                text=dsl.OneOf(t1.output, t2.output, t3.outputs['output_key']))
+
+        # hole punched through if
+        self.assertEqual(
+            roll_die_pipeline.pipeline_spec.components['comp-condition-2']
+            .output_definitions.parameters[
+                'pipelinechannel--print-and-return-Output'].parameter_type,
+            type_utils.STRING,
+        )
+        # hole punched through elif
+        self.assertEqual(
+            roll_die_pipeline.pipeline_spec.components['comp-condition-3']
+            .output_definitions.parameters[
+                'pipelinechannel--print-and-return-2-Output'].parameter_type,
+            type_utils.STRING,
+        )
+        # hole punched through else
+        self.assertEqual(
+            roll_die_pipeline.pipeline_spec.components['comp-condition-4']
+            .output_definitions.parameters[
+                'pipelinechannel--print-and-return-with-output-key-output_key']
+            .parameter_type,
+            type_utils.STRING,
+        )
+        # condition-branches surfaces
+        self.assertEqual(
+            roll_die_pipeline.pipeline_spec
+            .components['comp-condition-branches-1'].output_definitions
+            .parameters['pipelinechannel--condition-branches-1-oneof-1']
+            .parameter_type,
+            type_utils.STRING,
+        )
+        parameter_selectors = roll_die_pipeline.pipeline_spec.components[
+            'comp-condition-branches-1'].dag.outputs.parameters[
+                'pipelinechannel--condition-branches-1-oneof-1'].value_from_oneof.parameter_selectors
+        self.assertEqual(
+            parameter_selectors[0].output_parameter_key,
+            'pipelinechannel--print-and-return-Output',
+        )
+        self.assertEqual(
+            parameter_selectors[0].producer_subtask,
+            'condition-2',
+        )
+        self.assertEqual(
+            parameter_selectors[1].output_parameter_key,
+            'pipelinechannel--print-and-return-2-Output',
+        )
+        self.assertEqual(
+            parameter_selectors[1].producer_subtask,
+            'condition-3',
+        )
+        
self.assertEqual( + parameter_selectors[2].output_parameter_key, + 'pipelinechannel--print-and-return-with-output-key-output_key', + ) + self.assertEqual( + parameter_selectors[2].producer_subtask, + 'condition-4', + ) + # consumed from condition-branches + self.assertEqual( + roll_die_pipeline.pipeline_spec.root.dag.tasks['print-and-return-3'] + .inputs.parameters['text'].task_output_parameter.producer_task, + 'condition-branches-1', + ) + self.assertEqual( + roll_die_pipeline.pipeline_spec.root.dag.tasks['print-and-return-3'] + .inputs.parameters['text'].task_output_parameter + .output_parameter_key, + 'pipelinechannel--condition-branches-1-oneof-1', + ) + + def test_if_else_consumed_and_returned(self): + # tests if/else + # consumed and returned + # parameters + # same output key + + @dsl.pipeline + def flip_coin_pipeline() -> str: + flip_coin_task = flip_coin() + with dsl.If(flip_coin_task.output == 'heads'): + print_task_1 = print_and_return(text='Got heads!') + with dsl.Else(): + print_task_2 = print_and_return(text='Got tails!') + x = dsl.OneOf(print_task_1.output, print_task_2.output) + print_and_return(text=x) + return x + + # hole punched through if + self.assertEqual( + flip_coin_pipeline.pipeline_spec.components['comp-condition-2'] + .output_definitions.parameters[ + 'pipelinechannel--print-and-return-Output'].parameter_type, + type_utils.STRING, + ) + # hole punched through else + self.assertEqual( + flip_coin_pipeline.pipeline_spec.components['comp-condition-3'] + .output_definitions.parameters[ + 'pipelinechannel--print-and-return-2-Output'].parameter_type, + type_utils.STRING, + ) + # condition-branches surfaces + self.assertEqual( + flip_coin_pipeline.pipeline_spec + .components['comp-condition-branches-1'].output_definitions + .parameters['pipelinechannel--condition-branches-1-oneof-1'] + .parameter_type, + type_utils.STRING, + ) + parameter_selectors = flip_coin_pipeline.pipeline_spec.components[ + 'comp-condition-branches-1'].dag.outputs.parameters[ + 'pipelinechannel--condition-branches-1-oneof-1'].value_from_oneof.parameter_selectors + self.assertEqual( + parameter_selectors[0].output_parameter_key, + 'pipelinechannel--print-and-return-Output', + ) + self.assertEqual( + parameter_selectors[0].producer_subtask, + 'condition-2', + ) + self.assertEqual( + parameter_selectors[1].output_parameter_key, + 'pipelinechannel--print-and-return-2-Output', + ) + self.assertEqual( + parameter_selectors[1].producer_subtask, + 'condition-3', + ) + # consumed from condition-branches + self.assertEqual( + flip_coin_pipeline.pipeline_spec.root.dag + .tasks['print-and-return-3'].inputs.parameters['text'] + .task_output_parameter.producer_task, + 'condition-branches-1', + ) + self.assertEqual( + flip_coin_pipeline.pipeline_spec.root.dag + .tasks['print-and-return-3'].inputs.parameters['text'] + .task_output_parameter.output_parameter_key, + 'pipelinechannel--condition-branches-1-oneof-1', + ) + + # surfaced as output + self.assertEqual( + flip_coin_pipeline.pipeline_spec.root.dag.outputs + .parameters['Output'].value_from_parameter.producer_subtask, + 'condition-branches-1', + ) + self.assertEqual( + flip_coin_pipeline.pipeline_spec.root.dag.outputs + .parameters['Output'].value_from_parameter.output_parameter_key, + 'pipelinechannel--condition-branches-1-oneof-1', + ) + + def test_if_else_consumed_and_returned_artifacts(self): + # tests if/else + # consumed and returned + # artifacts + # same output key + + @dsl.pipeline + def flip_coin_pipeline() -> Artifact: + flip_coin_task = 
flip_coin() + with dsl.If(flip_coin_task.output == 'heads'): + print_task_1 = print_and_return_as_artifact(text='Got heads!') + with dsl.Else(): + print_task_2 = print_and_return_as_artifact(text='Got tails!') + x = dsl.OneOf(print_task_1.outputs['a'], print_task_2.outputs['a']) + print_artifact(a=x) + return x + + # hole punched through if + self.assertEqual( + flip_coin_pipeline.pipeline_spec.components['comp-condition-2'] + .output_definitions + .artifacts['pipelinechannel--print-and-return-as-artifact-a'] + .artifact_type.schema_title, + 'system.Artifact', + ) + # hole punched through else + self.assertEqual( + flip_coin_pipeline.pipeline_spec.components['comp-condition-3'] + .output_definitions + .artifacts['pipelinechannel--print-and-return-as-artifact-2-a'] + .artifact_type.schema_title, + 'system.Artifact', + ) + # condition-branches surfaces + self.assertEqual( + flip_coin_pipeline.pipeline_spec + .components['comp-condition-branches-1'].output_definitions + .artifacts['pipelinechannel--condition-branches-1-oneof-1'] + .artifact_type.schema_title, + 'system.Artifact', + ) + artifact_selectors = flip_coin_pipeline.pipeline_spec.components[ + 'comp-condition-branches-1'].dag.outputs.artifacts[ + 'pipelinechannel--condition-branches-1-oneof-1'].artifact_selectors + self.assertEqual( + artifact_selectors[0].output_artifact_key, + 'pipelinechannel--print-and-return-as-artifact-a', + ) + self.assertEqual( + artifact_selectors[0].producer_subtask, + 'condition-2', + ) + self.assertEqual( + artifact_selectors[1].output_artifact_key, + 'pipelinechannel--print-and-return-as-artifact-2-a', + ) + self.assertEqual( + artifact_selectors[1].producer_subtask, + 'condition-3', + ) + # consumed from condition-branches + self.assertEqual( + flip_coin_pipeline.pipeline_spec.root.dag.tasks['print-artifact'] + .inputs.artifacts['a'].task_output_artifact.producer_task, + 'condition-branches-1', + ) + self.assertEqual( + flip_coin_pipeline.pipeline_spec.root.dag.tasks['print-artifact'] + .inputs.artifacts['a'].task_output_artifact.output_artifact_key, + 'pipelinechannel--condition-branches-1-oneof-1', + ) + + # surfaced as output + self.assertEqual( + flip_coin_pipeline.pipeline_spec.root.dag.outputs + .artifacts['Output'].artifact_selectors[0].producer_subtask, + 'condition-branches-1', + ) + self.assertEqual( + flip_coin_pipeline.pipeline_spec.root.dag.outputs + .artifacts['Output'].artifact_selectors[0].output_artifact_key, + 'pipelinechannel--condition-branches-1-oneof-1', + ) + + def test_nested_under_condition_consumed(self): + # nested under loop and condition + # artifact + + @dsl.pipeline + def flip_coin_pipeline(execute_pipeline: bool): + with dsl.If(execute_pipeline == True): + flip_coin_task = flip_coin() + with dsl.If(flip_coin_task.output == 'heads'): + print_task_1 = print_and_return_as_artifact( + text='Got heads!') + with dsl.Else(): + print_task_2 = print_and_return_as_artifact( + text='Got tails!') + x = dsl.OneOf(print_task_1.outputs['a'], + print_task_2.outputs['a']) + print_artifact(a=x) + + # hole punched through if + self.assertEqual( + flip_coin_pipeline.pipeline_spec.components['comp-condition-3'] + .output_definitions + .artifacts['pipelinechannel--print-and-return-as-artifact-a'] + .artifact_type.schema_title, + 'system.Artifact', + ) + # hole punched through else + self.assertEqual( + flip_coin_pipeline.pipeline_spec.components['comp-condition-4'] + .output_definitions + .artifacts['pipelinechannel--print-and-return-as-artifact-2-a'] + .artifact_type.schema_title, + 
'system.Artifact', + ) + # condition-branches surfaces + self.assertEqual( + flip_coin_pipeline.pipeline_spec + .components['comp-condition-branches-2'].output_definitions + .artifacts['pipelinechannel--condition-branches-2-oneof-1'] + .artifact_type.schema_title, + 'system.Artifact', + ) + artifact_selectors = flip_coin_pipeline.pipeline_spec.components[ + 'comp-condition-branches-2'].dag.outputs.artifacts[ + 'pipelinechannel--condition-branches-2-oneof-1'].artifact_selectors + self.assertEqual( + artifact_selectors[0].output_artifact_key, + 'pipelinechannel--print-and-return-as-artifact-a', + ) + self.assertEqual( + artifact_selectors[0].producer_subtask, + 'condition-3', + ) + self.assertEqual( + artifact_selectors[1].output_artifact_key, + 'pipelinechannel--print-and-return-as-artifact-2-a', + ) + self.assertEqual( + artifact_selectors[1].producer_subtask, + 'condition-4', + ) + # consumed from condition-branches + self.assertEqual( + flip_coin_pipeline.pipeline_spec.components['comp-condition-1'].dag + .tasks['print-artifact'].inputs.artifacts['a'].task_output_artifact + .producer_task, + 'condition-branches-2', + ) + self.assertEqual( + flip_coin_pipeline.pipeline_spec.components['comp-condition-1'].dag + .tasks['print-artifact'].inputs.artifacts['a'].task_output_artifact + .output_artifact_key, + 'pipelinechannel--condition-branches-2-oneof-1', + ) + + def test_nested_under_condition_returned_raises(self): + # nested under loop and condition + # artifact + with self.assertRaisesRegex( + compiler_utils.InvalidTopologyException, + f'Pipeline outputs may only be returned from the top level of the pipeline function scope\. Got pipeline output dsl\.OneOf from within the control flow group dsl\.If\.' + ): + + @dsl.pipeline + def flip_coin_pipeline(execute_pipeline: bool): + with dsl.If(execute_pipeline == True): + flip_coin_task = flip_coin() + with dsl.If(flip_coin_task.output == 'heads'): + print_task_1 = print_and_return_as_artifact( + text='Got heads!') + with dsl.Else(): + print_task_2 = print_and_return_as_artifact( + text='Got tails!') + return dsl.OneOf(print_task_1.outputs['a'], + print_task_2.outputs['a']) + + def test_deeply_nested_consumed(self): + # nested under loop and condition and exit handler + # consumed + # artifact + + @dsl.pipeline + def flip_coin_pipeline(execute_pipeline: bool): + with dsl.ExitHandler(cleanup()): + with dsl.ParallelFor([1, 2, 3]): + with dsl.If(execute_pipeline == True): + flip_coin_task = flip_coin() + with dsl.If(flip_coin_task.output == 'heads'): + print_task_1 = print_and_return_as_artifact( + text='Got heads!') + with dsl.Else(): + print_task_2 = print_and_return_as_artifact( + text='Got tails!') + x = dsl.OneOf(print_task_1.outputs['a'], + print_task_2.outputs['a']) + print_artifact(a=x) + + self.assertIn( + 'condition-branches-5', flip_coin_pipeline.pipeline_spec + .components['comp-condition-4'].dag.tasks) + # consumed from condition-branches + self.assertEqual( + flip_coin_pipeline.pipeline_spec.components['comp-condition-4'].dag + .tasks['print-artifact'].inputs.artifacts['a'].task_output_artifact + .producer_task, + 'condition-branches-5', + ) + self.assertEqual( + flip_coin_pipeline.pipeline_spec.components['comp-condition-4'].dag + .tasks['print-artifact'].inputs.artifacts['a'].task_output_artifact + .output_artifact_key, + 'pipelinechannel--condition-branches-5-oneof-1', + ) + + def test_deeply_nested_returned_raises(self): + # nested under loop and condition + # returned + # artifact + + with self.assertRaisesRegex( + 
compiler_utils.InvalidTopologyException, + f'Pipeline outputs may only be returned from the top level of the pipeline function scope\. Got pipeline output dsl\.OneOf from within the control flow group dsl\.ParallelFor\.' + ): + + @dsl.pipeline + def flip_coin_pipeline(execute_pipeline: bool) -> str: + with dsl.ExitHandler(cleanup()): + with dsl.If(execute_pipeline == True): + with dsl.ParallelFor([1, 2, 3]): + flip_coin_task = flip_coin() + with dsl.If(flip_coin_task.output == 'heads'): + print_task_1 = print_and_return_as_artifact( + text='Got heads!') + with dsl.Else(): + print_task_2 = print_and_return_as_artifact( + text='Got tails!') + return dsl.OneOf(print_task_1.outputs['a'], + print_task_2.outputs['a']) + + def test_consume_at_wrong_level(self): + + with self.assertRaisesRegex( + compiler_utils.InvalidTopologyException, + f'Illegal task dependency across DSL context managers\. A downstream task cannot depend on an upstream task within a dsl\.If context unless the downstream is within that context too\. Found task print-artifact which depends on upstream task condition-branches-5 within an uncommon dsl\.If context\.' + ): + + @dsl.pipeline + def flip_coin_pipeline(execute_pipeline: bool): + with dsl.ExitHandler(cleanup()): + with dsl.ParallelFor([1, 2, 3]): + with dsl.If(execute_pipeline == True): + flip_coin_task = flip_coin() + with dsl.If(flip_coin_task.output == 'heads'): + print_task_1 = print_and_return_as_artifact( + text='Got heads!') + with dsl.Else(): + print_task_2 = print_and_return_as_artifact( + text='Got tails!') + x = dsl.OneOf(print_task_1.outputs['a'], + print_task_2.outputs['a']) + # this is one level dedented from the permitted case + print_artifact(a=x) + + def test_return_at_wrong_level(self): + with self.assertRaisesRegex( + compiler_utils.InvalidTopologyException, + f'Pipeline outputs may only be returned from the top level of the pipeline function scope\. Got pipeline output dsl\.OneOf from within the control flow group dsl\.If\.' 
+ ): + + @dsl.pipeline + def flip_coin_pipeline(execute_pipeline: bool): + with dsl.If(execute_pipeline == True): + flip_coin_task = flip_coin() + with dsl.If(flip_coin_task.output == 'heads'): + print_task_1 = print_and_return_as_artifact( + text='Got heads!') + with dsl.Else(): + print_task_2 = print_and_return_as_artifact( + text='Got tails!') + # this is returned at the right level, but not permitted since it's still effectively returning from within the dsl.If group + return dsl.OneOf(print_task_1.outputs['a'], + print_task_2.outputs['a']) + + def test_consumed_in_nested_groups(self): + + @dsl.pipeline + def roll_die_pipeline( + repeat: bool = True, + rounds: List[str] = ['a', 'b', 'c'], + ): + flip_coin_task = roll_three_sided_die() + with dsl.If(flip_coin_task.output == 'heads'): + t1 = print_and_return(text='Got heads!') + with dsl.Elif(flip_coin_task.output == 'tails'): + t2 = print_and_return(text='Got tails!') + with dsl.Else(): + t3 = print_and_return_with_output_key(text='Draw!') + x = dsl.OneOf(t1.output, t2.output, t3.outputs['output_key']) + + with dsl.ParallelFor(rounds): + with dsl.If(repeat == True): + print_and_return(text=x) + + # condition-branches surfaces + self.assertEqual( + roll_die_pipeline.pipeline_spec + .components['comp-condition-branches-1'].output_definitions + .parameters['pipelinechannel--condition-branches-1-oneof-1'] + .parameter_type, + type_utils.STRING, + ) + parameter_selectors = roll_die_pipeline.pipeline_spec.components[ + 'comp-condition-branches-1'].dag.outputs.parameters[ + 'pipelinechannel--condition-branches-1-oneof-1'].value_from_oneof.parameter_selectors + self.assertEqual( + parameter_selectors[0].output_parameter_key, + 'pipelinechannel--print-and-return-Output', + ) + self.assertEqual( + parameter_selectors[0].producer_subtask, + 'condition-2', + ) + self.assertEqual( + parameter_selectors[1].output_parameter_key, + 'pipelinechannel--print-and-return-2-Output', + ) + self.assertEqual( + parameter_selectors[1].producer_subtask, + 'condition-3', + ) + self.assertEqual( + parameter_selectors[2].output_parameter_key, + 'pipelinechannel--print-and-return-with-output-key-output_key', + ) + self.assertEqual( + parameter_selectors[2].producer_subtask, + 'condition-4', + ) + # condition points to correct upstream output + self.assertEqual( + roll_die_pipeline.pipeline_spec.components['comp-condition-6'] + .input_definitions.parameters[ + 'pipelinechannel--condition-branches-1-pipelinechannel--condition-branches-1-oneof-1'] + .parameter_type, type_utils.STRING) + # inner task consumes from condition input parameter + self.assertEqual( + roll_die_pipeline.pipeline_spec.components['comp-condition-6'].dag + .tasks['print-and-return-3'].inputs.parameters['text'] + .component_input_parameter, + 'pipelinechannel--condition-branches-1-pipelinechannel--condition-branches-1-oneof-1' + ) + + def test_oneof_in_fstring(self): + with self.assertRaisesRegex( + NotImplementedError, + f'dsl\.OneOf is not yet supported in f-strings\.'): + + @dsl.pipeline + def roll_die_pipeline(): + flip_coin_task = roll_three_sided_die() + with dsl.If(flip_coin_task.output == 'heads'): + t1 = print_and_return(text='Got heads!') + with dsl.Elif(flip_coin_task.output == 'tails'): + t2 = print_and_return(text='Got tails!') + with dsl.Else(): + t3 = print_and_return_with_output_key(text='Draw!') + print_and_return( + text=f"Final result: {dsl.OneOf(t1.output, t2.output, t3.outputs['output_key'])}" + ) + + def test_oneof_in_condition(self): + + @dsl.pipeline + def 
roll_die_pipeline(repeat_on: str = 'Got heads!'): + flip_coin_task = roll_three_sided_die() + with dsl.If(flip_coin_task.output == 'heads'): + t1 = print_and_return(text='Got heads!') + with dsl.Elif(flip_coin_task.output == 'tails'): + t2 = print_and_return(text='Got tails!') + with dsl.Else(): + t3 = print_and_return_with_output_key(text='Draw!') + x = dsl.OneOf(t1.output, t2.output, t3.outputs['output_key']) + + with dsl.If(x == repeat_on): + print_and_return(text=x) + + # condition-branches surfaces + self.assertEqual( + roll_die_pipeline.pipeline_spec + .components['comp-condition-branches-1'].output_definitions + .parameters['pipelinechannel--condition-branches-1-oneof-1'] + .parameter_type, + type_utils.STRING, + ) + parameter_selectors = roll_die_pipeline.pipeline_spec.components[ + 'comp-condition-branches-1'].dag.outputs.parameters[ + 'pipelinechannel--condition-branches-1-oneof-1'].value_from_oneof.parameter_selectors + self.assertEqual( + parameter_selectors[0].output_parameter_key, + 'pipelinechannel--print-and-return-Output', + ) + self.assertEqual( + parameter_selectors[0].producer_subtask, + 'condition-2', + ) + self.assertEqual( + parameter_selectors[1].output_parameter_key, + 'pipelinechannel--print-and-return-2-Output', + ) + self.assertEqual( + parameter_selectors[1].producer_subtask, + 'condition-3', + ) + self.assertEqual( + parameter_selectors[2].output_parameter_key, + 'pipelinechannel--print-and-return-with-output-key-output_key', + ) + self.assertEqual( + parameter_selectors[2].producer_subtask, + 'condition-4', + ) + # condition points to correct upstream output + self.assertEqual( + roll_die_pipeline.pipeline_spec.root.dag.tasks['condition-5'] + .trigger_policy.condition, + "inputs.parameter_values['pipelinechannel--condition-branches-1-pipelinechannel--condition-branches-1-oneof-1'] == inputs.parameter_values['pipelinechannel--repeat_on']" + ) + + def test_type_checking_parameters(self): + with self.assertRaisesRegex( + type_utils.InconsistentTypeException, + "Incompatible argument passed to the input 'val' of component 'print-int': Argument type 'STRING' is incompatible with the input type 'NUMBER_INTEGER'", + ): + + @dsl.component + def print_int(val: int): + print(val) + + @dsl.pipeline + def roll_die_pipeline(): + flip_coin_task = roll_three_sided_die() + with dsl.If(flip_coin_task.output == 'heads'): + t1 = print_and_return(text='Got heads!') + with dsl.Elif(flip_coin_task.output == 'tails'): + t2 = print_and_return(text='Got tails!') + with dsl.Else(): + t3 = print_and_return_with_output_key(text='Draw!') + print_int( + val=dsl.OneOf(t1.output, t2.output, + t3.outputs['output_key'])) + + def test_oneof_of_oneof(self): + with self.assertRaisesRegex( + ValueError, + r'dsl.OneOf cannot be used inside of another dsl\.OneOf\.'): + + @dsl.pipeline + def roll_die_pipeline() -> str: + outer_flip_coin_task = flip_coin() + with dsl.If(outer_flip_coin_task.output == 'heads'): + inner_flip_coin_task = flip_coin() + with dsl.If(inner_flip_coin_task.output == 'heads'): + t1 = print_and_return(text='Got heads!') + with dsl.Else(): + t2 = print_and_return(text='Got tails!') + t3 = dsl.OneOf(t1.output, t2.output) + with dsl.Else(): + t4 = print_and_return(text='First flip was not heads!') + return dsl.OneOf(t3, t4.output) + + if __name__ == '__main__': unittest.main() diff --git a/sdk/python/kfp/compiler/compiler_utils.py b/sdk/python/kfp/compiler/compiler_utils.py index ccc6730b1e3..0bc525e250f 100644 --- a/sdk/python/kfp/compiler/compiler_utils.py +++ 
b/sdk/python/kfp/compiler/compiler_utils.py @@ -14,9 +14,10 @@ """Utility methods for compiler implementation that is IR-agnostic.""" import collections -from copy import deepcopy +import copy from typing import DefaultDict, Dict, List, Mapping, Set, Tuple, Union +from kfp import dsl from kfp.dsl import for_loop from kfp.dsl import pipeline_channel from kfp.dsl import pipeline_context @@ -258,10 +259,9 @@ def get_inputs_for_all_groups( if isinstance(channel_to_add, pipeline_channel.PipelineChannel): channels_to_add.append(channel_to_add) - if channel.task_name: + if channel.task: # The PipelineChannel is produced by a task. - - upstream_task = pipeline.tasks[channel.task_name] + upstream_task = channel.task upstream_groups, downstream_groups = ( _get_uncommon_ancestors( task_name_to_parent_groups=task_name_to_parent_groups, @@ -462,46 +462,116 @@ def get_outputs_for_all_groups( } outputs = collections.defaultdict(dict) - + processed_oneofs: Set[pipeline_channel.OneOfMixin] = set() # handle dsl.Collected consumed by tasks for task in pipeline.tasks.values(): for channel in task.channel_inputs: - if not isinstance(channel, for_loop.Collected): - continue - producer_task = pipeline.tasks[channel.task_name] - consumer_task = task - - upstream_groups, downstream_groups = ( - _get_uncommon_ancestors( - task_name_to_parent_groups=task_name_to_parent_groups, - group_name_to_parent_groups=group_name_to_parent_groups, - task1=producer_task, - task2=consumer_task, - )) - validate_parallel_for_fan_in_consumption_legal( - consumer_task_name=consumer_task.name, - upstream_groups=upstream_groups, - group_name_to_group=group_name_to_group, - ) + # TODO: migrate Collected to OneOfMixin style implementation, + # then simplify this logic to align with OneOfMixin logic + if isinstance(channel, dsl.Collected): + producer_task = pipeline.tasks[channel.task_name] + consumer_task = task + + upstream_groups, downstream_groups = ( + _get_uncommon_ancestors( + task_name_to_parent_groups=task_name_to_parent_groups, + group_name_to_parent_groups=group_name_to_parent_groups, + task1=producer_task, + task2=consumer_task, + )) + validate_parallel_for_fan_in_consumption_legal( + consumer_task_name=consumer_task.name, + upstream_groups=upstream_groups, + group_name_to_group=group_name_to_group, + ) + + # producer_task's immediate parent group and the name by which + # to surface the channel + surfaced_output_name = additional_input_name_for_pipeline_channel( + channel) + + # the highest-level task group that "consumes" the + # collected output + parent_consumer = downstream_groups[0] + producer_task_name = upstream_groups.pop() + + # process from the upstream groups from the inside out + for upstream_name in reversed(upstream_groups): + outputs[upstream_name][ + surfaced_output_name] = make_new_channel_for_collected_outputs( + channel_name=channel.name, + starting_channel=channel.output, + task_name=producer_task_name, + ) + + # on each iteration, mutate the channel being consumed so + # that it references the last parent group surfacer + channel.name = surfaced_output_name + channel.task_name = upstream_name + + # for the next iteration, set the consumer to the current + # surfacer (parent group) + producer_task_name = upstream_name - # producer_task's immediate parent group and the name by which - # to surface the channel + parent_of_current_surfacer = group_name_to_parent_groups[ + upstream_name][-2] + if parent_consumer in group_name_to_children[ + parent_of_current_surfacer]: + break + + elif isinstance(channel, 
pipeline_channel.OneOfMixin):
+                for inner_channel in channel.channels:
+                    producer_task = pipeline.tasks[inner_channel.task_name]
+                    consumer_task = task
+                    upstream_groups, downstream_groups = (
+                        _get_uncommon_ancestors(
+                            task_name_to_parent_groups=task_name_to_parent_groups,
+                            group_name_to_parent_groups=group_name_to_parent_groups,
+                            task1=producer_task,
+                            task2=consumer_task,
+                        ))
+                    surfaced_output_name = additional_input_name_for_pipeline_channel(
+                        inner_channel)
+
+                    # 1. get the oneof
+                    # 2. find the task group that surfaced it
+                    # 3. find the inner tasks responsible
+
+                    for upstream_name in reversed(upstream_groups):
+                        # skip the first task processed, since we don't need to add new outputs for the innermost task
+                        if upstream_name == inner_channel.task.name:
+                            continue
+                        # once we've hit the outermost condition-branches group, we're done
+                        if upstream_name == channel.condition_branches_group.name:
+                            outputs[upstream_name][channel.name] = channel
+                            break
+
+                        # copy so we can update the inner channel for the next iteration
+                        # use copy not deepcopy, since deepcopy will needlessly copy the entire pipeline
+                        # this uses more memory than needed and some objects are uncopiable
+                        outputs[upstream_name][
+                            surfaced_output_name] = copy.copy(inner_channel)
+
+                        inner_channel.name = surfaced_output_name
+                        inner_channel.task_name = upstream_name
+
+                processed_oneofs.add(channel)
+
+    # handle dsl.Collected returned from pipeline
+    # TODO: consider migrating dsl.Collected returns to the pattern used by dsl.OneOf, where the OneOf constructor returns a parameter/artifact channel, which fits more cleanly into the existing compiler abstractions
+    for output_key, channel in pipeline_outputs_dict.items():
+        if isinstance(channel, for_loop.Collected):
             surfaced_output_name = additional_input_name_for_pipeline_channel(
                 channel)
-
-        # the highest-level task group that "consumes" the
-        # collected output
-        parent_consumer = downstream_groups[0]
+            upstream_groups = task_name_to_parent_groups[channel.task_name][1:]
             producer_task_name = upstream_groups.pop()
-
-        # process from the upstream groups from the inside out
+            # process upstream groups from the inside out, until getting to the pipeline level
             for upstream_name in reversed(upstream_groups):
-            outputs[upstream_name][
-                surfaced_output_name] = make_new_channel_for_collected_outputs(
-                    channel_name=channel.name,
-                    starting_channel=channel.output,
-                    task_name=producer_task_name,
-                )
+                new_channel = make_new_channel_for_collected_outputs(
+                    channel_name=channel.name,
+                    starting_channel=channel.output,
+                    task_name=producer_task_name,
+                )
 
         # on each iteration, mutate the channel being consumed so
         # that it references the last parent group surfacer
@@ -511,46 +581,46 @@ def get_outputs_for_all_groups(
         # for the next iteration, set the consumer to the current
         # surfacer (parent group)
         producer_task_name = upstream_name
-
-        parent_of_current_surfacer = group_name_to_parent_groups[
-            upstream_name][-2]
-        if parent_consumer in group_name_to_children[
-                parent_of_current_surfacer]:
-            break
-
-    # handle dsl.Collected returned from pipeline
-    for output_key, channel in pipeline_outputs_dict.items():
-        if isinstance(channel, for_loop.Collected):
-            surfaced_output_name = additional_input_name_for_pipeline_channel(
-                channel)
+        outputs[upstream_name][surfaced_output_name] = new_channel
+
+        # after surfacing from all inner TasksGroup, change the PipelineChannel output to also return from the correct TasksGroup
+        pipeline_outputs_dict[
+            output_key] = make_new_channel_for_collected_outputs(
+                    channel_name=surfaced_output_name,
+                    starting_channel=channel.output,
+                    task_name=upstream_name,
+                )
+        elif isinstance(channel, pipeline_channel.OneOfMixin):
+            # if the output has already been consumed by a task before it is returned, we don't need to reprocess it
+            if channel in processed_oneofs:
+                continue
+            for inner_channel in channel.channels:
+                producer_task = pipeline.tasks[inner_channel.task_name]
                 upstream_groups = task_name_to_parent_groups[
-                channel.task_name][1:]
-            producer_task_name = upstream_groups.pop()
-            # process upstream groups from the inside out, until getting to the pipeline level
-            for upstream_name in reversed(upstream_groups):
-                new_channel = make_new_channel_for_collected_outputs(
-                    channel_name=channel.name,
-                    starting_channel=channel.output,
-                    task_name=producer_task_name,
-                )
-
-                # on each iteration, mutate the channel being consumed so
-                # that it references the last parent group surfacer
-                channel.name = surfaced_output_name
-                channel.task_name = upstream_name
+                    inner_channel.task_name][1:]
+                surfaced_output_name = additional_input_name_for_pipeline_channel(
+                    inner_channel)
 
-                # for the next iteration, set the consumer to the current
-                # surfacer (parent group)
-                producer_task_name = upstream_name
-                outputs[upstream_name][surfaced_output_name] = new_channel
-
-            # after surfacing from all inner TasksGroup, change the PipelineChannel output to also return from the correct TasksGroup
-            pipeline_outputs_dict[
-                output_key] = make_new_channel_for_collected_outputs(
-                    channel_name=surfaced_output_name,
-                    starting_channel=channel.output,
-                    task_name=upstream_name,
-                )
+                # 1. get the oneof
+                # 2. find the task group that surfaced it
+                # 3. find the inner tasks responsible
+                for upstream_name in reversed(upstream_groups):
+                    # skip the first task processed, since we don't need to add new outputs for the innermost task
+                    if upstream_name == inner_channel.task.name:
+                        continue
+                    # once we've hit the outermost condition-branches group, we're done
+                    if upstream_name == channel.condition_branches_group.name:
+                        outputs[upstream_name][channel.name] = channel
+                        break
+
+                    # copy so we can update the inner channel for the next iteration
+                    # use copy not deepcopy, since deepcopy will needlessly copy the entire pipeline
+                    # this uses more memory than needed and some objects are uncopiable
+                    outputs[upstream_name][surfaced_output_name] = copy.copy(
+                        inner_channel)
+
+                    inner_channel.name = surfaced_output_name
+                    inner_channel.task_name = upstream_name
 
     return outputs, pipeline_outputs_dict
 
@@ -633,22 +703,17 @@ def get_dependencies(
     """
     dependencies = collections.defaultdict(set)
    for task in pipeline.tasks.values():
-        upstream_task_names = set()
+        upstream_task_names: Set[Union[pipeline_task.PipelineTask,
+                                       tasks_group.TasksGroup]] = set()
         task_condition_inputs = list(condition_channels[task.name])
-        for channel in task.channel_inputs + task_condition_inputs:
-            if channel.task_name:
-                upstream_task_names.add(channel.task_name)
-        upstream_task_names |= set(task.dependent_tasks)
-
-        for upstream_task_name in upstream_task_names:
-            # the dependent op could be either a BaseOp or an opsgroup
-            if upstream_task_name in pipeline.tasks:
-                upstream_task = pipeline.tasks[upstream_task_name]
-            elif upstream_task_name in group_name_to_group:
-                upstream_task = group_name_to_group[upstream_task_name]
-            else:
-                raise ValueError(
-                    f'Compiler cannot find task: {upstream_task_name}.')
+        all_channels = task.channel_inputs + task_condition_inputs
+        upstream_task_names.update(
+            {channel.task for channel in all_channels if channel.task})
+        # dependent_tasks are tasks on which .after was called and can only be the names of PipelineTasks, not TasksGroups
+        upstream_task_names.update(
+            {pipeline.tasks[after_task] for after_task in task.dependent_tasks})
+
+        for upstream_task in upstream_task_names:
 
             upstream_groups, downstream_groups = _get_uncommon_ancestors(
                 task_name_to_parent_groups=task_name_to_parent_groups,
@@ -658,7 +723,7 @@
             )
 
             # uncommon upstream ancestor check
-            uncommon_upstream_groups = deepcopy(upstream_groups)
+            uncommon_upstream_groups = copy.deepcopy(upstream_groups)
             uncommon_upstream_groups.remove(
                 upstream_task.name
             )  # because a task's `upstream_groups` contains the task's name
@@ -675,6 +740,8 @@
                     raise InvalidTopologyException(
                         f'{ILLEGAL_CROSS_DAG_ERROR_PREFIX} A downstream task cannot depend on an upstream task within a dsl.{dependent_group.__class__.__name__} context unless the downstream is within that context too. Found task {task.name} which depends on upstream task {upstream_task.name} within an uncommon dsl.{dependent_group.__class__.__name__} context.'
                     )
+                # TODO: migrate Collected to OneOfMixin style implementation,
+                # then make this validation dsl.Collected-aware
                 elif isinstance(dependent_group, tasks_group.ParallelFor):
                     raise InvalidTopologyException(
                         f'{ILLEGAL_CROSS_DAG_ERROR_PREFIX} A downstream task cannot depend on an upstream task within a dsl.{dependent_group.__class__.__name__} context unless the downstream is within that context too or the outputs are begin fanned-in to a list using dsl.{for_loop.Collected.__name__}. Found task {task.name} which depends on upstream task {upstream_task.name} within an uncommon dsl.{dependent_group.__class__.__name__} context.'
diff --git a/sdk/python/kfp/compiler/pipeline_spec_builder.py b/sdk/python/kfp/compiler/pipeline_spec_builder.py
index 1c0b7aa4635..1f972133c79 100644
--- a/sdk/python/kfp/compiler/pipeline_spec_builder.py
+++ b/sdk/python/kfp/compiler/pipeline_spec_builder.py
@@ -417,12 +417,13 @@ def _build_component_spec_from_component_spec_structure(
     return component_spec
 
 
-def _connect_dag_outputs(
+def connect_single_dag_output(
     component_spec: pipeline_spec_pb2.ComponentSpec,
     output_name: str,
     output_channel: pipeline_channel.PipelineChannel,
 ) -> None:
-    """Connects dag output to a subtask output.
+    """Connects a DAG output to a subtask output when the subtask output
+    contains only one channel (i.e., not OneOfMixin).
 
     Args:
         component_spec: The component spec to modify its dag outputs.
@@ -451,14 +452,71 @@
             output_name].value_from_parameter.output_parameter_key = output_channel.name
 
 
+def connect_oneof_dag_output(
+    component_spec: pipeline_spec_pb2.ComponentSpec,
+    output_name: str,
+    oneof_output: pipeline_channel.OneOfMixin,
+) -> None:
+    """Connects an output to the OneOf output returned by the DAG's internal
+    condition-branches group.
+
+    Args:
+        component_spec: The component spec to modify its DAG outputs.
+        output_name: The name of the DAG output.
+        oneof_output: The OneOfMixin object returned by the pipeline (OneOf in user code).
+    """
+    if isinstance(oneof_output, pipeline_channel.OneOfArtifact):
+        if output_name not in component_spec.output_definitions.artifacts:
+            raise ValueError(
+                f'Pipeline or component output not defined: {output_name}. You may be missing a type annotation.'
+            )
+        for channel in oneof_output.channels:
+            component_spec.dag.outputs.artifacts[
+                output_name].artifact_selectors.append(
+                    pipeline_spec_pb2.DagOutputsSpec.ArtifactSelectorSpec(
+                        producer_subtask=channel.task_name,
+                        output_artifact_key=channel.name,
+                    ))
+    if isinstance(oneof_output, pipeline_channel.OneOfParameter):
+        if output_name not in component_spec.output_definitions.parameters:
+            raise ValueError(
+                f'Pipeline or component output not defined: {output_name}. You may be missing a type annotation.'
+            )
+        for channel in oneof_output.channels:
+            component_spec.dag.outputs.parameters[
+                output_name].value_from_oneof.parameter_selectors.append(
+                    pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                        producer_subtask=channel.task_name,
+                        output_parameter_key=channel.name,
+                    ))
+
+
 def _build_dag_outputs(
     component_spec: pipeline_spec_pb2.ComponentSpec,
     dag_outputs: Dict[str, pipeline_channel.PipelineChannel],
 ) -> None:
-    """Builds DAG output spec."""
+    """Connects the DAG's outputs to a TasksGroup's ComponentSpec and
+    validates they are present in the component interface.
+
+    Args:
+        component_spec: The ComponentSpec.
+        dag_outputs: Dictionary of output key to output channel.
+    """
     for output_name, output_channel in dag_outputs.items():
-        _connect_dag_outputs(component_spec, output_name, output_channel)
-    # Valid dag outputs covers all outptus in component definition.
+        if not isinstance(output_channel, pipeline_channel.PipelineChannel):
+            raise ValueError(
+                f"Got unknown pipeline output '{output_name}' of type {type(output_channel)}."
+            )
+        connect_single_dag_output(component_spec, output_name, output_channel)
+
+    validate_dag_outputs(component_spec)
+
+
+def validate_dag_outputs(
+    component_spec: pipeline_spec_pb2.ComponentSpec) -> None:
+    """Validates the DAG's ComponentSpec specifies the source task for all of
+    its outputs (output_definitions)."""
     for output_name in component_spec.output_definitions.artifacts:
         if output_name not in component_spec.dag.outputs.artifacts:
             raise ValueError(f'Missing pipeline output: {output_name}.')
@@ -467,6 +525,31 @@
             raise ValueError(f'Missing pipeline output: {output_name}.')
 
 
+def build_oneof_dag_outputs(
+    component_spec: pipeline_spec_pb2.ComponentSpec,
+    oneof_outputs: Dict[str, pipeline_channel.OneOfMixin],
+) -> None:
+    """Connects the DAG's OneOf outputs to a TasksGroup's ComponentSpec and
+    validates they are present in the component interface.
+
+    Args:
+        component_spec: The ComponentSpec.
+        oneof_outputs: Dictionary of output key to OneOf output channel.
+    """
+    for output_name, oneof_output in oneof_outputs.items():
+        for channel in oneof_output.channels:
+            if not isinstance(channel, pipeline_channel.PipelineChannel):
+                raise ValueError(
+                    f"Got unknown pipeline output '{output_name}' of type {type(channel)}."
+ ) + connect_oneof_dag_output( + component_spec, + output_name, + oneof_output, + ) + validate_dag_outputs(component_spec) + + def build_importer_spec_for_task( task: pipeline_task.PipelineTask ) -> pipeline_spec_pb2.PipelineDeploymentConfig.ImporterSpec: @@ -1290,7 +1373,7 @@ def build_spec_by_group( elif isinstance(subgroup, tasks_group.ConditionBranches): subgroup_component_spec = build_component_spec_for_group( input_pipeline_channels=subgroup_input_channels, - output_pipeline_channels={}, + output_pipeline_channels=subgroup_output_channels, ) subgroup_task_spec = build_task_spec_for_group( @@ -1299,6 +1382,9 @@ def build_spec_by_group( tasks_in_current_dag=tasks_in_current_dag, is_parent_component_root=is_parent_component_root, ) + # oneof is the only type of output a ConditionBranches group can have + build_oneof_dag_outputs(subgroup_component_spec, + subgroup_output_channels) else: raise RuntimeError( @@ -1694,6 +1780,28 @@ def _rename_component_refs( old_name_to_new_name[old_component_name]].CopyFrom(component_spec) +def validate_pipeline_outputs_dict( + pipeline_outputs_dict: Dict[str, pipeline_channel.PipelineChannel]): + for channel in pipeline_outputs_dict.values(): + if isinstance(channel, for_loop.Collected): + # this validation doesn't apply to Collected + continue + + elif isinstance(channel, pipeline_channel.OneOfMixin): + if channel.condition_branches_group.parent_task_group.group_type != tasks_group.TasksGroupType.PIPELINE: + raise compiler_utils.InvalidTopologyException( + f'Pipeline outputs may only be returned from the top level of the pipeline function scope. Got pipeline output dsl.{pipeline_channel.OneOf.__name__} from within the control flow group dsl.{channel.condition_branches_group.parent_task_group.__class__.__name__}.' + ) + + elif isinstance(channel, pipeline_channel.PipelineChannel): + if channel.task.parent_task_group.group_type != tasks_group.TasksGroupType.PIPELINE: + raise compiler_utils.InvalidTopologyException( + f'Pipeline outputs may only be returned from the top level of the pipeline function scope. Got pipeline output from within the control flow group dsl.{channel.task.parent_task_group.__class__.__name__}.' 
+ ) + else: + raise ValueError(f'Got unknown pipeline output: {channel}.') + + def create_pipeline_spec( pipeline: pipeline_context.Pipeline, component_spec: structures.ComponentSpec, @@ -1729,6 +1837,8 @@ def create_pipeline_spec( # an output from a task in a condition group, for example, which isn't # caught until submission time using Vertex SDK client pipeline_outputs_dict = convert_pipeline_outputs_to_dict(pipeline_outputs) + validate_pipeline_outputs_dict(pipeline_outputs_dict) + root_group = pipeline.groups[0] all_groups = compiler_utils.get_all_groups(root_group) diff --git a/sdk/python/kfp/dsl/__init__.py b/sdk/python/kfp/dsl/__init__.py index 001226b02cf..c2c70c847d5 100644 --- a/sdk/python/kfp/dsl/__init__.py +++ b/sdk/python/kfp/dsl/__init__.py @@ -229,8 +229,10 @@ def my_pipeline(): if os.environ.get('_KFP_RUNTIME', 'false') != 'true': from kfp.dsl.component_decorator import component from kfp.dsl.container_component_decorator import container_component + # TODO: Collected should be moved to pipeline_channel.py, consistent with OneOf from kfp.dsl.for_loop import Collected from kfp.dsl.importer_node import importer + from kfp.dsl.pipeline_channel import OneOf from kfp.dsl.pipeline_context import pipeline from kfp.dsl.pipeline_task import PipelineTask from kfp.dsl.placeholders import ConcatPlaceholder @@ -252,6 +254,7 @@ def my_pipeline(): 'If', 'Elif', 'Else', + 'OneOf', 'ExitHandler', 'ParallelFor', 'Collected', diff --git a/sdk/python/kfp/dsl/for_loop.py b/sdk/python/kfp/dsl/for_loop.py index 53815766315..e49c9a2951e 100644 --- a/sdk/python/kfp/dsl/for_loop.py +++ b/sdk/python/kfp/dsl/for_loop.py @@ -274,6 +274,7 @@ def _get_name_override(self, loop_arg_name: str, subvar_name: str) -> str: return f'{loop_arg_name}{self.SUBVAR_NAME_DELIMITER}{subvar_name}' +# TODO: migrate Collected to OneOfMixin style implementation class Collected(pipeline_channel.PipelineChannel): """For collecting into a list the output from a task in dsl.ParallelFor loops. @@ -313,3 +314,13 @@ def __init__( channel_type=channel_type, task_name=output.task_name, ) + self._validate_no_oneof_channel(self.output) + + def _validate_no_oneof_channel( + self, channel: Union[pipeline_channel.PipelineParameterChannel, + pipeline_channel.PipelineArtifactChannel] + ) -> None: + if isinstance(channel, pipeline_channel.OneOfMixin): + raise ValueError( + f'dsl.{pipeline_channel.OneOf.__name__} cannot be used inside of dsl.{Collected.__name__}.' 
+            )
diff --git a/sdk/python/kfp/dsl/pipeline_channel.py b/sdk/python/kfp/dsl/pipeline_channel.py
index 4841928bbf4..7f93ca38dc0 100644
--- a/sdk/python/kfp/dsl/pipeline_channel.py
+++ b/sdk/python/kfp/dsl/pipeline_channel.py
@@ -102,12 +102,31 @@ def __init__(
         self.task_name = task_name or None
 
         from kfp.dsl import pipeline_context
-        default_pipeline = pipeline_context.Pipeline.get_default_pipeline()
-        if self.task_name is not None and default_pipeline is not None and default_pipeline.tasks:
-            self.task = pipeline_context.Pipeline.get_default_pipeline().tasks[
-                self.task_name]
-        else:
-            self.task = None
+        self.pipeline = pipeline_context.Pipeline.get_default_pipeline()
+
+    @property
+    def task(self) -> Union['PipelineTask', 'TasksGroup']:
+        # TODO: migrate Collected to OneOfMixin style implementation,
+        # then move this out of a property
+        if self.task_name is None or self.pipeline is None:
+            return None
+
+        if self.task_name in self.pipeline.tasks:
+            return self.pipeline.tasks[self.task_name]
+
+        from kfp.compiler import compiler_utils
+        all_groups = compiler_utils.get_all_groups(self.pipeline.groups[0])
+        # pipeline hasn't exited, so it doesn't have a name
+        all_groups_no_pipeline = all_groups[1:]
+        group_name_to_group = {
+            group.name: group for group in all_groups_no_pipeline
+        }
+        if self.task_name in group_name_to_group:
+            return group_name_to_group[self.task_name]
+
+        raise ValueError(
+            f"PipelineChannel task name '{self.task_name}' not found in pipeline."
+        )
 
     @property
     def full_name(self) -> str:
@@ -265,6 +284,234 @@ def __init__(
         )
 
 
+class OneOfMixin(PipelineChannel):
+    """Shared functionality for OneOfParameter and OneOfArtifact."""
+
+    def _set_condition_branches_group(
+        self, channels: List[Union[PipelineParameterChannel,
+                                   PipelineArtifactChannel]]
+    ) -> None:
+        # avoid circular import
+        from kfp.dsl import tasks_group
+
+        # .condition_branches_group could really be collapsed into just .task,
+        # but we prefer keeping both for clarity in the rest of the compiler
+        # code. When the code is logically related to a
+        # condition_branches_group, it aids understanding to reference this
+        # attribute name. When the code is trying to treat the OneOfMixin like
+        # a typical PipelineChannel, it aids understanding to reference .task.
+        self.condition_branches_group: tasks_group.ConditionBranches = channels[
+            0].task.parent_task_group.parent_task_group
+
+    def _make_oneof_name(self) -> str:
+        # avoid circular imports
+        from kfp.compiler import compiler_utils
+
+        # This is a different type of "injected channel".
+        # We know that this output will _always_ be a pipeline channel, so we
+        # set the pipeline-channel-- prefix immediately (here).
+        # In the downstream compiler logic, we get to treat this output like a
+        # normal task output.
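+        # For example, the first OneOf under a group named
+        # 'condition-branches-1' surfaces as
+        # 'pipelinechannel--condition-branches-1-oneof-1' (the exact suffix
+        # comes from the group's _get_oneof_id counter; see the compiler
+        # tests above for this name appearing in the compiled spec).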
+ return compiler_utils.additional_input_name_for_pipeline_channel( + f'{self.condition_branches_group.name}-oneof-{self.condition_branches_group._get_oneof_id()}' + ) + + def _validate_channels( + self, + channels: List[Union[PipelineParameterChannel, + PipelineArtifactChannel]], + ): + self._validate_no_collected_channel(channels) + self._validate_no_oneof_channel(channels) + self._validate_no_mix_of_parameters_and_artifacts(channels) + self._validate_has_else_group(self.condition_branches_group) + + def _validate_no_collected_channel( + self, channels: List[Union[PipelineParameterChannel, + PipelineArtifactChannel]] + ) -> None: + # avoid circular imports + from kfp.dsl import for_loop + if any(isinstance(channel, for_loop.Collected) for channel in channels): + raise ValueError( + f'dsl.{for_loop.Collected.__name__} cannot be used inside of dsl.{OneOf.__name__}.' + ) + + def _validate_no_oneof_channel( + self, channels: List[Union[PipelineParameterChannel, + PipelineArtifactChannel]] + ) -> None: + if any(isinstance(channel, OneOfMixin) for channel in channels): + raise ValueError( + f'dsl.{OneOf.__name__} cannot be used inside of another dsl.{OneOf.__name__}.' + ) + + def _validate_no_mix_of_parameters_and_artifacts( + self, channels: List[Union[PipelineParameterChannel, + PipelineArtifactChannel]] + ) -> None: + readable_name_map = { + PipelineParameterChannel: 'parameter', + PipelineArtifactChannel: 'artifact', + OneOfParameter: 'parameter', + OneOfArtifact: 'artifact', + } + # if channels[0] is any subclass of a PipelineParameterChannel + # check the rest of the channels against that parent type + # ensures check permits OneOfParameter and PipelineParameterChannel + # to be passed to OneOf together + if isinstance(channels[0], PipelineParameterChannel): + expected_type = PipelineParameterChannel + else: + expected_type = PipelineArtifactChannel + + for i, channel in enumerate(channels[1:], start=1): + if not isinstance(channel, expected_type): + raise TypeError( + f'Task outputs passed to dsl.{OneOf.__name__} must be the same type. Got two channels with different types: {readable_name_map[expected_type]} at index 0 and {readable_name_map[type(channel)]} at index {i}.' + ) + + def _validate_has_else_group( + self, + parent_group: 'tasks_group.ConditionBranches', + ) -> None: + # avoid circular imports + from kfp.dsl import tasks_group + if not isinstance(parent_group.groups[-1], tasks_group.Else): + raise ValueError( + f'dsl.{OneOf.__name__} must include an output from a task in a dsl.{tasks_group.Else.__name__} group to ensure at least one output is available at runtime.' 
+            )
+
+    def __str__(self):
+        # Supporting oneof in f-strings is technically feasible, but would
+        # require somehow encoding all of the oneof channels into the
+        # f-string.
+        # Another way to do this would be to maintain a pipeline-level
+        # map of PipelineChannels and encode a lookup key in the f-string.
+        # The combination of OneOf and an f-string is not common, so prefer
+        # deferring implementation.
+        raise NotImplementedError(
+            f'dsl.{OneOf.__name__} is not yet supported in f-strings.')
+
+    @property
+    def pattern(self) -> str:
+        # override self.pattern to avoid calling __str__, allowing us to
+        # block f-strings for now; this also makes OneOfMixin hashable for
+        # use in sets/dicts
+        task_name = self.task_name or ''
+        name = self.name
+        channel_type = self.channel_type or ''
+        if isinstance(channel_type, dict):
+            channel_type = json.dumps(channel_type)
+        return _PIPELINE_CHANNEL_PLACEHOLDER_TEMPLATE % (task_name, name,
+                                                         channel_type)
+
+
+# splitting out OneOf into subclasses significantly decreases the amount of
+# branching in downstream compiler logic, since the
+# isinstance(<channel>, PipelineParameterChannel/PipelineArtifactChannel)
+# checks continue to behave in desirable ways
+class OneOfParameter(PipelineParameterChannel, OneOfMixin):
+    """OneOf that results in a parameter channel for all downstream tasks."""
+
+    def __init__(self, channels: List[PipelineParameterChannel]) -> None:
+        self.channels = channels
+        self._set_condition_branches_group(channels)
+        super().__init__(
+            name=self._make_oneof_name(),
+            channel_type=channels[0].channel_type,
+            task_name=None,
+        )
+        self.task_name = self.condition_branches_group.name
+        self.channels = channels
+        self._validate_channels(channels)
+        self._validate_same_kfp_type(channels)
+
+    def _validate_same_kfp_type(
+            self, channels: List[PipelineParameterChannel]) -> None:
+        expected_type = channels[0].channel_type
+        for i, channel in enumerate(channels[1:], start=1):
+            if channel.channel_type != expected_type:
+                raise TypeError(
+                    f'Task outputs passed to dsl.{OneOf.__name__} must be the same type. Got two channels with different types: {expected_type} at index 0 and {channel.channel_type} at index {i}.'
+                )
+
+
+class OneOfArtifact(PipelineArtifactChannel, OneOfMixin):
+    """OneOf that results in an artifact channel for all downstream tasks."""
+
+    def __init__(self, channels: List[PipelineArtifactChannel]) -> None:
+        self.channels = channels
+        self._set_condition_branches_group(channels)
+        super().__init__(
+            name=self._make_oneof_name(),
+            channel_type=channels[0].channel_type,
+            task_name=None,
+            is_artifact_list=channels[0].is_artifact_list,
+        )
+        self.task_name = self.condition_branches_group.name
+        self._validate_channels(channels)
+        self._validate_same_kfp_type(channels)
+
+    def _validate_same_kfp_type(
+            self, channels: List[PipelineArtifactChannel]) -> None:
+        # Unlike for component interface type checking, where anything is
+        # passable to Artifact, we should require the output artifacts for a
+        # OneOf to be the same. This reduces the complexity/ambiguity for the
+        # user of the actual type checking logic. What should the type checking
+        # behavior be if the OneOf surfaces an Artifact and a Dataset? We can
+        # always loosen this in a backward-compatible way in the future, so
+        # prefer starting conservatively.
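+        # For example (illustrative task and output names), a
+        # dsl.OneOf(t1.outputs['a'], t2.outputs['d']) raises a TypeError when
+        # one branch outputs system.Artifact and the other system.Dataset,
+        # even though either could be consumed where an Artifact is accepted.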
+ expected_type = channels[0].channel_type + expected_is_list = channels[0].is_artifact_list + for i, channel in enumerate(channels[1:], start=1): + if channel.channel_type != expected_type or channel.is_artifact_list != expected_is_list: + raise TypeError( + f'Task outputs passed to dsl.{OneOf.__name__} must be the same type. Got two channels with different types: {expected_type} at index 0 and {channel.channel_type} at index {i}.' + ) + + +class OneOf: + """For collecting mutually exclusive outputs from conditional branches into + a single pipeline channel. + + Args: + channels: The channels to collect into a OneOf. Must be of the same type. + + Example: + :: + + @dsl.pipeline + def flip_coin_pipeline() -> str: + flip_coin_task = flip_coin() + with dsl.If(flip_coin_task.output == 'heads'): + print_task_1 = print_and_return(text='Got heads!') + with dsl.Else(): + print_task_2 = print_and_return(text='Got tails!') + + # use the output from the branch that gets executed + oneof = dsl.OneOf(print_task_1.output, print_task_2.output) + + # consume it + print_and_return(text=oneof) + + # return it + return oneof + """ + + def __new__( + cls, *channels: Union[PipelineParameterChannel, PipelineArtifactChannel] + ) -> Union[OneOfParameter, OneOfArtifact]: + first_channel = channels[0] + if isinstance(first_channel, PipelineParameterChannel): + return OneOfParameter(channels=list(channels)) + elif isinstance(first_channel, PipelineArtifactChannel): + return OneOfArtifact(channels=list(channels)) + else: + raise ValueError( + f'Got unknown input to dsl.{OneOf.__name__} with type {type(first_channel)}.' + ) + + def create_pipeline_channel( name: str, channel_type: Union[str, Dict], diff --git a/sdk/python/kfp/dsl/pipeline_channel_test.py b/sdk/python/kfp/dsl/pipeline_channel_test.py index 4de0e84a254..db4120bb768 100644 --- a/sdk/python/kfp/dsl/pipeline_channel_test.py +++ b/sdk/python/kfp/dsl/pipeline_channel_test.py @@ -13,10 +13,14 @@ # limitations under the License. 
"""Tests for kfp.dsl.pipeline_channel.""" +from typing import List import unittest from absl.testing import parameterized from kfp import dsl +from kfp.dsl import Artifact +from kfp.dsl import Dataset +from kfp.dsl import Output from kfp.dsl import pipeline_channel @@ -156,19 +160,229 @@ def test_extract_pipeline_channels(self): self.assertListEqual([p1, p2, p3], params) +@dsl.component +def string_comp() -> str: + return 'text' + + +@dsl.component +def list_comp() -> List[str]: + return ['text'] + + +@dsl.component +def roll_three_sided_die() -> str: + import random + val = random.randint(0, 2) + + if val == 0: + return 'heads' + elif val == 1: + return 'tails' + else: + return 'draw' + + +@dsl.component +def print_and_return(text: str) -> str: + print(text) + return text + + class TestCanAccessTask(unittest.TestCase): def test(self): - @dsl.component - def comp() -> str: - return 'text' - @dsl.pipeline def my_pipeline(): - op1 = comp() + op1 = string_comp() self.assertEqual(op1.output.task, op1) +class TestOneOfAndCollectedNotComposable(unittest.TestCase): + + def test_collected_in_oneof(self): + with self.assertRaisesRegex( + ValueError, + 'dsl.Collected cannot be used inside of dsl.OneOf.'): + + @dsl.pipeline + def my_pipeline(x: str): + with dsl.If(x == 'foo'): + t1 = list_comp() + with dsl.Else(): + with dsl.ParallelFor([1, 2, 3]): + t2 = string_comp() + collected = dsl.Collected(t2.output) + # test cases doesn't return or pass to task to ensure validation is in the OneOf + dsl.OneOf(t1.output, collected) + + def test_oneof_in_collected(self): + with self.assertRaisesRegex( + ValueError, + 'dsl.OneOf cannot be used inside of dsl.Collected.'): + + @dsl.pipeline + def my_pipeline(x: str): + with dsl.ParallelFor([1, 2, 3]): + with dsl.If(x == 'foo'): + t1 = string_comp() + with dsl.Else(): + t2 = string_comp() + oneof = dsl.OneOf(t1.output, t2.output) + # test cases doesn't return or pass to task to ensure validation is in the Collected constructor + dsl.Collected(oneof) + + +class TestOneOfRequiresSameType(unittest.TestCase): + + def test_same_parameter_type(self): + + @dsl.pipeline + def my_pipeline(x: str) -> str: + with dsl.If(x == 'foo'): + t1 = string_comp() + with dsl.Else(): + t2 = string_comp() + return dsl.OneOf(t1.output, t2.output) + + self.assertEqual( + my_pipeline.pipeline_spec.components['comp-condition-branches-1'] + .output_definitions.parameters[ + 'pipelinechannel--condition-branches-1-oneof-1'].parameter_type, + 3) + + def test_different_parameter_types(self): + + with self.assertRaisesRegex( + TypeError, + r'Task outputs passed to dsl\.OneOf must be the same type. Got two channels with different types: String at index 0 and typing\.List\[str\] at index 1\.' 
+ ): + + @dsl.pipeline + def my_pipeline(x: str) -> str: + with dsl.If(x == 'foo'): + t1 = string_comp() + with dsl.Else(): + t2 = list_comp() + return dsl.OneOf(t1.output, t2.output) + + def test_same_artifact_type(self): + + @dsl.component + def artifact_comp(out: Output[Artifact]): + with open(out.path, 'w') as f: + f.write('foo') + + @dsl.pipeline + def my_pipeline(x: str) -> Artifact: + with dsl.If(x == 'foo'): + t1 = artifact_comp() + with dsl.Else(): + t2 = artifact_comp() + return dsl.OneOf(t1.outputs['out'], t2.outputs['out']) + + self.assertEqual( + my_pipeline.pipeline_spec.components['comp-condition-branches-1'] + .output_definitions + .artifacts['pipelinechannel--condition-branches-1-oneof-1'] + .artifact_type.schema_title, + 'system.Artifact', + ) + self.assertEqual( + my_pipeline.pipeline_spec.components['comp-condition-branches-1'] + .output_definitions + .artifacts['pipelinechannel--condition-branches-1-oneof-1'] + .artifact_type.schema_version, + '0.0.1', + ) + + def test_different_artifact_type(self): + + @dsl.component + def artifact_comp_one(out: Output[Artifact]): + with open(out.path, 'w') as f: + f.write('foo') + + @dsl.component + def artifact_comp_two(out: Output[Dataset]): + with open(out.path, 'w') as f: + f.write('foo') + + with self.assertRaisesRegex( + TypeError, + r'Task outputs passed to dsl\.OneOf must be the same type. Got two channels with different types: system.Artifact@0.0.1 at index 0 and system.Dataset@0.0.1 at index 1\.' + ): + + @dsl.pipeline + def my_pipeline(x: str) -> Artifact: + with dsl.If(x == 'foo'): + t1 = artifact_comp_one() + with dsl.Else(): + t2 = artifact_comp_two() + return dsl.OneOf(t1.outputs['out'], t2.outputs['out']) + + def test_different_artifact_type_due_to_list(self): + # if we ever support list of artifact outputs from components, this test will fail, which is good because it needs to be changed + + with self.assertRaisesRegex( + ValueError, + r"Output lists of artifacts are only supported for pipelines\. Got output list of artifacts for output parameter 'out' of component 'artifact-comp-two'\." + ): + + @dsl.component + def artifact_comp_one(out: Output[Artifact]): + with open(out.path, 'w') as f: + f.write('foo') + + @dsl.component + def artifact_comp_two(out: Output[List[Artifact]]): + with open(out.path, 'w') as f: + f.write('foo') + + @dsl.pipeline + def my_pipeline(x: str) -> Artifact: + with dsl.If(x == 'foo'): + t1 = artifact_comp_one() + with dsl.Else(): + t2 = artifact_comp_two() + return dsl.OneOf(t1.outputs['out'], t2.outputs['out']) + + def test_parameters_mixed_with_artifacts(self): + + @dsl.component + def artifact_comp(out: Output[Artifact]): + with open(out.path, 'w') as f: + f.write('foo') + + with self.assertRaisesRegex( + TypeError, + r'Task outputs passed to dsl\.OneOf must be the same type\. Got two channels with different types: artifact at index 0 and parameter at index 1\.' + ): + + @dsl.pipeline + def my_pipeline(x: str) -> str: + with dsl.If(x == 'foo'): + t1 = artifact_comp() + with dsl.Else(): + t2 = string_comp() + return dsl.OneOf(t1.output, t2.output) + + def test_no_else_raises(self): + with self.assertRaisesRegex( + ValueError, + r'dsl\.OneOf must include an output from a task in a dsl\.Else group to ensure at least one output is available at runtime\.' 
+ ): + + @dsl.pipeline + def roll_die_pipeline(): + flip_coin_task = roll_three_sided_die() + with dsl.If(flip_coin_task.output == 'heads'): + t1 = print_and_return(text='Got heads!') + with dsl.Elif(flip_coin_task.output == 'tails'): + t2 = print_and_return(text='Got tails!') + print_and_return(text=dsl.OneOf(t1.output, t2.output)) + + if __name__ == '__main__': unittest.main() diff --git a/sdk/python/kfp/dsl/pipeline_context.py b/sdk/python/kfp/dsl/pipeline_context.py index 72ada197ae5..4881bc5680c 100644 --- a/sdk/python/kfp/dsl/pipeline_context.py +++ b/sdk/python/kfp/dsl/pipeline_context.py @@ -182,6 +182,7 @@ def push_tasks_group(self, group: 'tasks_group.TasksGroup'): group: A TasksGroup. Typically it is one of ExitHandler, Condition, and ParallelFor. """ + group.parent_task_group = self.get_parent_group() self.groups[-1].groups.append(group) self.groups.append(group) @@ -195,6 +196,9 @@ def get_last_tasks_group(self) -> Optional['tasks_group.TasksGroup']: groups = self.groups[-1].groups return groups[-1] if groups else None + def get_parent_group(self) -> 'tasks_group.TasksGroup': + return self.groups[-1] + def remove_task_from_groups(self, task: pipeline_task.PipelineTask): """Removes a task from the pipeline. diff --git a/sdk/python/kfp/dsl/tasks_group.py b/sdk/python/kfp/dsl/tasks_group.py index 2d4bb8d6932..3cfa737c392 100644 --- a/sdk/python/kfp/dsl/tasks_group.py +++ b/sdk/python/kfp/dsl/tasks_group.py @@ -68,6 +68,8 @@ def __init__( self.display_name = name self.dependencies = [] self.is_root = is_root + # backref to parent, set when the pipeline is called in pipeline_context + self.parent_task_group: Optional[TasksGroup] = None def __enter__(self): if not pipeline_context.Pipeline.get_default_pipeline(): @@ -142,6 +144,7 @@ def __init__( class ConditionBranches(TasksGroup): + _oneof_id = 0 def __init__(self) -> None: super().__init__( @@ -150,6 +153,16 @@ def __init__(self) -> None: is_root=False, ) + def _get_oneof_id(self) -> int: + """Incrementor for uniquely identifying a OneOf for the parent + ConditionBranches group. + + This is analogous to incrementing a unique identifier for tasks + groups belonging to a pipeline. 
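+
+        For example, the first dsl.OneOf created under a ConditionBranches
+        group gets id 1, the second gets id 2, and so on.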
+ """ + self._oneof_id += 1 + return self._oneof_id + class _ConditionBase(TasksGroup): """Parent class for condition control flow context managers (Condition, If, diff --git a/sdk/python/test_data/pipelines/if_elif_else_complex.py b/sdk/python/test_data/pipelines/if_elif_else_complex.py index 45efe58cd27..ea616d9bdf1 100644 --- a/sdk/python/test_data/pipelines/if_elif_else_complex.py +++ b/sdk/python/test_data/pipelines/if_elif_else_complex.py @@ -59,18 +59,24 @@ def lucky_number_pipeline(add_drumroll: bool = True, even_or_odd_task = is_even_or_odd(num=int_task.output) with dsl.If(even_or_odd_task.output == 'even'): - print_and_return(text='Got a low even number!') + t1 = print_and_return(text='Got a low even number!') with dsl.Else(): - print_and_return(text='Got a low odd number!') + t2 = print_and_return(text='Got a low odd number!') + + repeater_task = print_and_return( + text=dsl.OneOf(t1.output, t2.output)) with dsl.Elif(int_task.output > 5000): even_or_odd_task = is_even_or_odd(num=int_task.output) with dsl.If(even_or_odd_task.output == 'even'): - print_and_return(text='Got a high even number!') + t3 = print_and_return(text='Got a high even number!') with dsl.Else(): - print_and_return(text='Got a high odd number!') + t4 = print_and_return(text='Got a high odd number!') + + repeater_task = print_and_return( + text=dsl.OneOf(t3.output, t4.output)) with dsl.Else(): print_and_return( diff --git a/sdk/python/test_data/pipelines/if_elif_else_complex.yaml b/sdk/python/test_data/pipelines/if_elif_else_complex.yaml index 9f14ee8b69f..b1f5520ba17 100644 --- a/sdk/python/test_data/pipelines/if_elif_else_complex.yaml +++ b/sdk/python/test_data/pipelines/if_elif_else_complex.yaml @@ -7,46 +7,66 @@ components: comp-condition-11: dag: + outputs: + parameters: + pipelinechannel--print-and-return-5-Output: + valueFromParameter: + outputParameterKey: Output + producerSubtask: print-and-return-5 tasks: - print-and-return-4: + print-and-return-5: cachingOptions: enableCache: true componentRef: - name: comp-print-and-return-4 + name: comp-print-and-return-5 inputs: parameters: text: runtimeValue: constant: Got a high even number! taskInfo: - name: print-and-return-4 + name: print-and-return-5 inputDefinitions: parameters: pipelinechannel--int-0-to-9999-Output: parameterType: NUMBER_INTEGER pipelinechannel--is-even-or-odd-2-Output: parameterType: STRING + outputDefinitions: + parameters: + pipelinechannel--print-and-return-5-Output: + parameterType: STRING comp-condition-12: dag: + outputs: + parameters: + pipelinechannel--print-and-return-6-Output: + valueFromParameter: + outputParameterKey: Output + producerSubtask: print-and-return-6 tasks: - print-and-return-5: + print-and-return-6: cachingOptions: enableCache: true componentRef: - name: comp-print-and-return-5 + name: comp-print-and-return-6 inputs: parameters: text: runtimeValue: constant: Got a high odd number! 
taskInfo: - name: print-and-return-5 + name: print-and-return-6 inputDefinitions: parameters: pipelinechannel--int-0-to-9999-Output: parameterType: NUMBER_INTEGER pipelinechannel--is-even-or-odd-2-Output: parameterType: STRING + outputDefinitions: + parameters: + pipelinechannel--print-and-return-6-Output: + parameterType: STRING comp-condition-13: dag: tasks: @@ -64,11 +84,11 @@ components: triggerPolicy: condition: inputs.parameter_values['pipelinechannel--repeat_if_lucky_number'] == true - print-and-return-6: + print-and-return-8: cachingOptions: enableCache: true componentRef: - name: comp-print-and-return-6 + name: comp-print-and-return-8 inputs: parameters: text: @@ -76,7 +96,7 @@ components: constant: 'Announcing: Got the lucky number 5000! A one in 10,000 chance.' taskInfo: - name: print-and-return-6 + name: print-and-return-8 inputDefinitions: parameters: pipelinechannel--int-0-to-9999-Output: @@ -153,6 +173,12 @@ components: parameterType: NUMBER_INTEGER comp-condition-6: dag: + outputs: + parameters: + pipelinechannel--print-and-return-2-Output: + valueFromParameter: + outputParameterKey: Output + producerSubtask: print-and-return-2 tasks: print-and-return-2: cachingOptions: @@ -172,8 +198,18 @@ components: parameterType: NUMBER_INTEGER pipelinechannel--is-even-or-odd-Output: parameterType: STRING + outputDefinitions: + parameters: + pipelinechannel--print-and-return-2-Output: + parameterType: STRING comp-condition-7: dag: + outputs: + parameters: + pipelinechannel--print-and-return-3-Output: + valueFromParameter: + outputParameterKey: Output + producerSubtask: print-and-return-3 tasks: print-and-return-3: cachingOptions: @@ -193,6 +229,10 @@ components: parameterType: NUMBER_INTEGER pipelinechannel--is-even-or-odd-Output: parameterType: STRING + outputDefinitions: + parameters: + pipelinechannel--print-and-return-3-Output: + parameterType: STRING comp-condition-8: dag: tasks: @@ -222,6 +262,21 @@ components: componentInputParameter: pipelinechannel--int-0-to-9999-Output taskInfo: name: is-even-or-odd + print-and-return-4: + cachingOptions: + enableCache: true + componentRef: + name: comp-print-and-return-4 + dependentTasks: + - condition-branches-5 + inputs: + parameters: + text: + taskOutputParameter: + outputParameterKey: pipelinechannel--condition-branches-5-oneof-1 + producerTask: condition-branches-5 + taskInfo: + name: print-and-return-4 inputDefinitions: parameters: pipelinechannel--int-0-to-9999-Output: @@ -255,12 +310,36 @@ components: componentInputParameter: pipelinechannel--int-0-to-9999-Output taskInfo: name: is-even-or-odd-2 + print-and-return-7: + cachingOptions: + enableCache: true + componentRef: + name: comp-print-and-return-7 + dependentTasks: + - condition-branches-10 + inputs: + parameters: + text: + taskOutputParameter: + outputParameterKey: pipelinechannel--condition-branches-10-oneof-1 + producerTask: condition-branches-10 + taskInfo: + name: print-and-return-7 inputDefinitions: parameters: pipelinechannel--int-0-to-9999-Output: parameterType: NUMBER_INTEGER comp-condition-branches-10: dag: + outputs: + parameters: + pipelinechannel--condition-branches-10-oneof-1: + valueFromOneof: + parameterSelectors: + - outputParameterKey: pipelinechannel--print-and-return-5-Output + producerSubtask: condition-11 + - outputParameterKey: pipelinechannel--print-and-return-6-Output + producerSubtask: condition-12 tasks: condition-11: componentRef: @@ -296,6 +375,10 @@ components: parameterType: NUMBER_INTEGER pipelinechannel--is-even-or-odd-2-Output: parameterType: 
STRING + outputDefinitions: + parameters: + pipelinechannel--condition-branches-10-oneof-1: + parameterType: STRING comp-condition-branches-4: dag: tasks: @@ -347,6 +430,15 @@ components: parameterType: BOOLEAN comp-condition-branches-5: dag: + outputs: + parameters: + pipelinechannel--condition-branches-5-oneof-1: + valueFromOneof: + parameterSelectors: + - outputParameterKey: pipelinechannel--print-and-return-2-Output + producerSubtask: condition-6 + - outputParameterKey: pipelinechannel--print-and-return-3-Output + producerSubtask: condition-7 tasks: condition-6: componentRef: @@ -382,6 +474,10 @@ components: parameterType: NUMBER_INTEGER pipelinechannel--is-even-or-odd-Output: parameterType: STRING + outputDefinitions: + parameters: + pipelinechannel--condition-branches-5-oneof-1: + parameterType: STRING comp-for-loop-1: dag: outputs: @@ -443,11 +539,11 @@ components: comp-for-loop-16: dag: tasks: - print-and-return-7: + print-and-return-9: cachingOptions: enableCache: true componentRef: - name: comp-print-and-return-7 + name: comp-print-and-return-9 inputs: parameters: text: @@ -455,7 +551,7 @@ components: constant: 'Announcing again: Got the lucky number 5000! A one in 10,000 chance.' taskInfo: - name: print-and-return-7 + name: print-and-return-9 inputDefinitions: parameters: pipelinechannel--int-0-to-9999-Output: @@ -560,6 +656,26 @@ components: parameters: Output: parameterType: STRING + comp-print-and-return-8: + executorLabel: exec-print-and-return-8 + inputDefinitions: + parameters: + text: + parameterType: STRING + outputDefinitions: + parameters: + Output: + parameterType: STRING + comp-print-and-return-9: + executorLabel: exec-print-and-return-9 + inputDefinitions: + parameters: + text: + parameterType: STRING + outputDefinitions: + parameters: + Output: + parameterType: STRING comp-print-ints: executorLabel: exec-print-ints inputDefinitions: @@ -849,6 +965,64 @@ deploymentSpec: - 'program_path=$(mktemp -d) + printf "%s" "$0" > "$program_path/ephemeral_component.py" + + _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" + + ' + - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ + \ *\n\ndef print_and_return(text: str) -> str:\n print(text)\n return\ + \ text\n\n" + image: python:3.7 + exec-print-and-return-8: + container: + args: + - --executor_input + - '{{$}}' + - --function_to_execute + - print_and_return + command: + - sh + - -c + - "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ + \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ + \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ + \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ + $0\" \"$@\"\n" + - sh + - -ec + - 'program_path=$(mktemp -d) + + + printf "%s" "$0" > "$program_path/ephemeral_component.py" + + _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" + + ' + - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ + \ *\n\ndef print_and_return(text: str) -> str:\n print(text)\n return\ + \ text\n\n" + image: python:3.7 + exec-print-and-return-9: + container: + args: + - --executor_input + - '{{$}}' + - --function_to_execute + - print_and_return + command: + - sh + - -c + - "\nif ! 
[ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ + \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ + \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ + \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ + $0\" \"$@\"\n" + - sh + - -ec + - 'program_path=$(mktemp -d) + + printf "%s" "$0" > "$program_path/ephemeral_component.py" _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" diff --git a/sdk/python/test_data/pipelines/if_elif_else.py b/sdk/python/test_data/pipelines/if_elif_else_with_oneof_parameters.py similarity index 66% rename from sdk/python/test_data/pipelines/if_elif_else.py rename to sdk/python/test_data/pipelines/if_elif_else_with_oneof_parameters.py index fdaa3428f64..7e0dc1b57fc 100644 --- a/sdk/python/test_data/pipelines/if_elif_else.py +++ b/sdk/python/test_data/pipelines/if_elif_else_with_oneof_parameters.py @@ -34,18 +34,32 @@ def print_and_return(text: str) -> str: return text +@dsl.component +def special_print_and_return(text: str, output_key: dsl.OutputPath(str)): + print('Got the special state:', text) + with open(output_key, 'w') as f: + f.write(text) + + @dsl.pipeline -def roll_die_pipeline(): +def roll_die_pipeline() -> str: flip_coin_task = flip_three_sided_die() with dsl.If(flip_coin_task.output == 'heads'): - print_and_return(text='Got heads!') + t1 = print_and_return(text='Got heads!') with dsl.Elif(flip_coin_task.output == 'tails'): - print_and_return(text='Got tails!') + t2 = print_and_return(text='Got tails!') with dsl.Else(): - print_and_return(text='Draw!') + t3 = special_print_and_return(text='Draw!') + return dsl.OneOf(t1.output, t2.output, t3.outputs['output_key']) + + +@dsl.pipeline +def outer_pipeline() -> str: + flip_coin_task = roll_die_pipeline() + return print_and_return(text=flip_coin_task.output).output if __name__ == '__main__': compiler.Compiler().compile( - pipeline_func=roll_die_pipeline, + pipeline_func=outer_pipeline, package_path=__file__.replace('.py', '.yaml')) diff --git a/sdk/python/test_data/pipelines/if_elif_else_with_oneof_parameters.yaml b/sdk/python/test_data/pipelines/if_elif_else_with_oneof_parameters.yaml new file mode 100644 index 00000000000..09159947603 --- /dev/null +++ b/sdk/python/test_data/pipelines/if_elif_else_with_oneof_parameters.yaml @@ -0,0 +1,420 @@ +# PIPELINE DEFINITION +# Name: outer-pipeline +# Outputs: +# Output: str +components: + comp-condition-2: + dag: + outputs: + parameters: + pipelinechannel--print-and-return-Output: + valueFromParameter: + outputParameterKey: Output + producerSubtask: print-and-return + tasks: + print-and-return: + cachingOptions: + enableCache: true + componentRef: + name: comp-print-and-return + inputs: + parameters: + text: + runtimeValue: + constant: Got heads! + taskInfo: + name: print-and-return + inputDefinitions: + parameters: + pipelinechannel--flip-three-sided-die-Output: + parameterType: STRING + outputDefinitions: + parameters: + pipelinechannel--print-and-return-Output: + parameterType: STRING + comp-condition-3: + dag: + outputs: + parameters: + pipelinechannel--print-and-return-2-Output: + valueFromParameter: + outputParameterKey: Output + producerSubtask: print-and-return-2 + tasks: + print-and-return-2: + cachingOptions: + enableCache: true + componentRef: + name: comp-print-and-return-2 + inputs: + parameters: + text: + runtimeValue: + constant: Got tails! 
+ taskInfo: + name: print-and-return-2 + inputDefinitions: + parameters: + pipelinechannel--flip-three-sided-die-Output: + parameterType: STRING + outputDefinitions: + parameters: + pipelinechannel--print-and-return-2-Output: + parameterType: STRING + comp-condition-4: + dag: + outputs: + parameters: + pipelinechannel--special-print-and-return-output_key: + valueFromParameter: + outputParameterKey: output_key + producerSubtask: special-print-and-return + tasks: + special-print-and-return: + cachingOptions: + enableCache: true + componentRef: + name: comp-special-print-and-return + inputs: + parameters: + text: + runtimeValue: + constant: Draw! + taskInfo: + name: special-print-and-return + inputDefinitions: + parameters: + pipelinechannel--flip-three-sided-die-Output: + parameterType: STRING + outputDefinitions: + parameters: + pipelinechannel--special-print-and-return-output_key: + parameterType: STRING + comp-condition-branches-1: + dag: + outputs: + parameters: + pipelinechannel--condition-branches-1-oneof-1: + valueFromOneof: + parameterSelectors: + - outputParameterKey: pipelinechannel--print-and-return-Output + producerSubtask: condition-2 + - outputParameterKey: pipelinechannel--print-and-return-2-Output + producerSubtask: condition-3 + - outputParameterKey: pipelinechannel--special-print-and-return-output_key + producerSubtask: condition-4 + tasks: + condition-2: + componentRef: + name: comp-condition-2 + inputs: + parameters: + pipelinechannel--flip-three-sided-die-Output: + componentInputParameter: pipelinechannel--flip-three-sided-die-Output + taskInfo: + name: condition-2 + triggerPolicy: + condition: inputs.parameter_values['pipelinechannel--flip-three-sided-die-Output'] + == 'heads' + condition-3: + componentRef: + name: comp-condition-3 + inputs: + parameters: + pipelinechannel--flip-three-sided-die-Output: + componentInputParameter: pipelinechannel--flip-three-sided-die-Output + taskInfo: + name: condition-3 + triggerPolicy: + condition: '!(inputs.parameter_values[''pipelinechannel--flip-three-sided-die-Output''] + == ''heads'') && inputs.parameter_values[''pipelinechannel--flip-three-sided-die-Output''] + == ''tails''' + condition-4: + componentRef: + name: comp-condition-4 + inputs: + parameters: + pipelinechannel--flip-three-sided-die-Output: + componentInputParameter: pipelinechannel--flip-three-sided-die-Output + taskInfo: + name: condition-4 + triggerPolicy: + condition: '!(inputs.parameter_values[''pipelinechannel--flip-three-sided-die-Output''] + == ''heads'') && !(inputs.parameter_values[''pipelinechannel--flip-three-sided-die-Output''] + == ''tails'')' + inputDefinitions: + parameters: + pipelinechannel--flip-three-sided-die-Output: + parameterType: STRING + outputDefinitions: + parameters: + pipelinechannel--condition-branches-1-oneof-1: + parameterType: STRING + comp-flip-three-sided-die: + executorLabel: exec-flip-three-sided-die + outputDefinitions: + parameters: + Output: + parameterType: STRING + comp-print-and-return: + executorLabel: exec-print-and-return + inputDefinitions: + parameters: + text: + parameterType: STRING + outputDefinitions: + parameters: + Output: + parameterType: STRING + comp-print-and-return-2: + executorLabel: exec-print-and-return-2 + inputDefinitions: + parameters: + text: + parameterType: STRING + outputDefinitions: + parameters: + Output: + parameterType: STRING + comp-print-and-return-3: + executorLabel: exec-print-and-return-3 + inputDefinitions: + parameters: + text: + parameterType: STRING + outputDefinitions: + parameters: + 
Output: + parameterType: STRING + comp-roll-die-pipeline: + dag: + outputs: + parameters: + Output: + valueFromParameter: + outputParameterKey: pipelinechannel--condition-branches-1-oneof-1 + producerSubtask: condition-branches-1 + tasks: + condition-branches-1: + componentRef: + name: comp-condition-branches-1 + dependentTasks: + - flip-three-sided-die + inputs: + parameters: + pipelinechannel--flip-three-sided-die-Output: + taskOutputParameter: + outputParameterKey: Output + producerTask: flip-three-sided-die + taskInfo: + name: condition-branches-1 + flip-three-sided-die: + cachingOptions: + enableCache: true + componentRef: + name: comp-flip-three-sided-die + taskInfo: + name: flip-three-sided-die + outputDefinitions: + parameters: + Output: + parameterType: STRING + comp-special-print-and-return: + executorLabel: exec-special-print-and-return + inputDefinitions: + parameters: + text: + parameterType: STRING + outputDefinitions: + parameters: + output_key: + parameterType: STRING +deploymentSpec: + executors: + exec-flip-three-sided-die: + container: + args: + - --executor_input + - '{{$}}' + - --function_to_execute + - flip_three_sided_die + command: + - sh + - -c + - "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ + \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ + \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ + \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ + $0\" \"$@\"\n" + - sh + - -ec + - 'program_path=$(mktemp -d) + + + printf "%s" "$0" > "$program_path/ephemeral_component.py" + + _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" + + ' + - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ + \ *\n\ndef flip_three_sided_die() -> str:\n import random\n val =\ + \ random.randint(0, 2)\n\n if val == 0:\n return 'heads'\n \ + \ elif val == 1:\n return 'tails'\n else:\n return 'draw'\n\ + \n" + image: python:3.7 + exec-print-and-return: + container: + args: + - --executor_input + - '{{$}}' + - --function_to_execute + - print_and_return + command: + - sh + - -c + - "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ + \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ + \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ + \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ + $0\" \"$@\"\n" + - sh + - -ec + - 'program_path=$(mktemp -d) + + + printf "%s" "$0" > "$program_path/ephemeral_component.py" + + _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" + + ' + - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ + \ *\n\ndef print_and_return(text: str) -> str:\n print(text)\n return\ + \ text\n\n" + image: python:3.7 + exec-print-and-return-2: + container: + args: + - --executor_input + - '{{$}}' + - --function_to_execute + - print_and_return + command: + - sh + - -c + - "\nif ! 
[ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ + \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ + \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ + \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ + $0\" \"$@\"\n" + - sh + - -ec + - 'program_path=$(mktemp -d) + + + printf "%s" "$0" > "$program_path/ephemeral_component.py" + + _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" + + ' + - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ + \ *\n\ndef print_and_return(text: str) -> str:\n print(text)\n return\ + \ text\n\n" + image: python:3.7 + exec-print-and-return-3: + container: + args: + - --executor_input + - '{{$}}' + - --function_to_execute + - print_and_return + command: + - sh + - -c + - "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ + \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ + \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ + \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ + $0\" \"$@\"\n" + - sh + - -ec + - 'program_path=$(mktemp -d) + + + printf "%s" "$0" > "$program_path/ephemeral_component.py" + + _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" + + ' + - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ + \ *\n\ndef print_and_return(text: str) -> str:\n print(text)\n return\ + \ text\n\n" + image: python:3.7 + exec-special-print-and-return: + container: + args: + - --executor_input + - '{{$}}' + - --function_to_execute + - special_print_and_return + command: + - sh + - -c + - "\nif ! 
[ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ + \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ + \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ + \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ + $0\" \"$@\"\n" + - sh + - -ec + - 'program_path=$(mktemp -d) + + + printf "%s" "$0" > "$program_path/ephemeral_component.py" + + _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" + + ' + - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ + \ *\n\ndef special_print_and_return(text: str, output_key: dsl.OutputPath(str)):\n\ + \ print('Got the special state:', text)\n with open(output_key, 'w')\ + \ as f:\n f.write(text)\n\n" + image: python:3.7 +pipelineInfo: + name: outer-pipeline +root: + dag: + outputs: + parameters: + Output: + valueFromParameter: + outputParameterKey: Output + producerSubtask: print-and-return + tasks: + print-and-return: + cachingOptions: + enableCache: true + componentRef: + name: comp-print-and-return-3 + dependentTasks: + - roll-die-pipeline + inputs: + parameters: + text: + taskOutputParameter: + outputParameterKey: Output + producerTask: roll-die-pipeline + taskInfo: + name: print-and-return + roll-die-pipeline: + cachingOptions: + enableCache: true + componentRef: + name: comp-roll-die-pipeline + taskInfo: + name: roll-die-pipeline + outputDefinitions: + parameters: + Output: + parameterType: STRING +schemaVersion: 2.1.0 +sdkVersion: kfp-2.3.0 diff --git a/sdk/python/test_data/pipelines/if_else.yaml b/sdk/python/test_data/pipelines/if_else.yaml deleted file mode 100644 index bdd9a8d0cb3..00000000000 --- a/sdk/python/test_data/pipelines/if_else.yaml +++ /dev/null @@ -1,214 +0,0 @@ -# PIPELINE DEFINITION -# Name: flip-coin-pipeline -components: - comp-condition-2: - dag: - tasks: - print-and-return: - cachingOptions: - enableCache: true - componentRef: - name: comp-print-and-return - inputs: - parameters: - text: - runtimeValue: - constant: Got heads! - taskInfo: - name: print-and-return - inputDefinitions: - parameters: - pipelinechannel--flip-coin-Output: - parameterType: STRING - comp-condition-3: - dag: - tasks: - print-and-return-2: - cachingOptions: - enableCache: true - componentRef: - name: comp-print-and-return-2 - inputs: - parameters: - text: - runtimeValue: - constant: Got tails! 
- taskInfo: - name: print-and-return-2 - inputDefinitions: - parameters: - pipelinechannel--flip-coin-Output: - parameterType: STRING - comp-condition-branches-1: - dag: - tasks: - condition-2: - componentRef: - name: comp-condition-2 - inputs: - parameters: - pipelinechannel--flip-coin-Output: - componentInputParameter: pipelinechannel--flip-coin-Output - taskInfo: - name: condition-2 - triggerPolicy: - condition: inputs.parameter_values['pipelinechannel--flip-coin-Output'] - == 'heads' - condition-3: - componentRef: - name: comp-condition-3 - inputs: - parameters: - pipelinechannel--flip-coin-Output: - componentInputParameter: pipelinechannel--flip-coin-Output - taskInfo: - name: condition-3 - triggerPolicy: - condition: '!(inputs.parameter_values[''pipelinechannel--flip-coin-Output''] - == ''heads'')' - inputDefinitions: - parameters: - pipelinechannel--flip-coin-Output: - parameterType: STRING - comp-flip-coin: - executorLabel: exec-flip-coin - outputDefinitions: - parameters: - Output: - parameterType: STRING - comp-print-and-return: - executorLabel: exec-print-and-return - inputDefinitions: - parameters: - text: - parameterType: STRING - outputDefinitions: - parameters: - Output: - parameterType: STRING - comp-print-and-return-2: - executorLabel: exec-print-and-return-2 - inputDefinitions: - parameters: - text: - parameterType: STRING - outputDefinitions: - parameters: - Output: - parameterType: STRING -deploymentSpec: - executors: - exec-flip-coin: - container: - args: - - --executor_input - - '{{$}}' - - --function_to_execute - - flip_coin - command: - - sh - - -c - - "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ - \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ - \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ - \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ - $0\" \"$@\"\n" - - sh - - -ec - - 'program_path=$(mktemp -d) - - - printf "%s" "$0" > "$program_path/ephemeral_component.py" - - _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" - - ' - - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ - \ *\n\ndef flip_coin() -> str:\n import random\n return 'heads' if\ - \ random.randint(0, 1) == 0 else 'tails'\n\n" - image: python:3.7 - exec-print-and-return: - container: - args: - - --executor_input - - '{{$}}' - - --function_to_execute - - print_and_return - command: - - sh - - -c - - "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ - \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ - \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ - \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ - $0\" \"$@\"\n" - - sh - - -ec - - 'program_path=$(mktemp -d) - - - printf "%s" "$0" > "$program_path/ephemeral_component.py" - - _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" - - ' - - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ - \ *\n\ndef print_and_return(text: str) -> str:\n print(text)\n return\ - \ text\n\n" - image: python:3.7 - exec-print-and-return-2: - container: - args: - - --executor_input - - '{{$}}' - - --function_to_execute - - print_and_return - command: - - sh - - -c - - "\nif ! 
[ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ - \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ - \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ - \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ - $0\" \"$@\"\n" - - sh - - -ec - - 'program_path=$(mktemp -d) - - - printf "%s" "$0" > "$program_path/ephemeral_component.py" - - _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" - - ' - - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ - \ *\n\ndef print_and_return(text: str) -> str:\n print(text)\n return\ - \ text\n\n" - image: python:3.7 -pipelineInfo: - name: flip-coin-pipeline -root: - dag: - tasks: - condition-branches-1: - componentRef: - name: comp-condition-branches-1 - dependentTasks: - - flip-coin - inputs: - parameters: - pipelinechannel--flip-coin-Output: - taskOutputParameter: - outputParameterKey: Output - producerTask: flip-coin - taskInfo: - name: condition-branches-1 - flip-coin: - cachingOptions: - enableCache: true - componentRef: - name: comp-flip-coin - taskInfo: - name: flip-coin -schemaVersion: 2.1.0 -sdkVersion: kfp-2.3.0 diff --git a/sdk/python/test_data/pipelines/if_else_with_oneof_artifacts.py b/sdk/python/test_data/pipelines/if_else_with_oneof_artifacts.py new file mode 100644 index 00000000000..4dc549c1738 --- /dev/null +++ b/sdk/python/test_data/pipelines/if_else_with_oneof_artifacts.py @@ -0,0 +1,60 @@ +# Copyright 2023 The Kubeflow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from kfp import dsl +from kfp.dsl import Artifact +from kfp.dsl import Input +from kfp.dsl import Output + + +@dsl.component +def flip_coin() -> str: + import random + return 'heads' if random.randint(0, 1) == 0 else 'tails' + + +@dsl.component +def param_to_artifact(val: str, a: Output[Artifact]): + with open(a.path, 'w') as f: + f.write(val) + + +@dsl.component +def print_artifact(a: Input[Artifact]): + with open(a.path) as f: + print(f.read()) + + +@dsl.pipeline +def flip_coin_pipeline() -> Artifact: + flip_coin_task = flip_coin() + with dsl.If(flip_coin_task.output == 'heads'): + t1 = param_to_artifact(val=flip_coin_task.output) + with dsl.Else(): + t2 = param_to_artifact(val=flip_coin_task.output) + oneof = dsl.OneOf(t1.outputs['a'], t2.outputs['a']) + print_artifact(a=oneof) + return oneof + + +@dsl.pipeline +def outer_pipeline(): + flip_coin_task = flip_coin_pipeline() + print_artifact(a=flip_coin_task.output) + + +if __name__ == '__main__': + from kfp import compiler + compiler.Compiler().compile( + pipeline_func=outer_pipeline, + package_path=__file__.replace('.py', '.yaml')) diff --git a/sdk/python/test_data/pipelines/if_else_with_oneof_artifacts.yaml b/sdk/python/test_data/pipelines/if_else_with_oneof_artifacts.yaml new file mode 100644 index 00000000000..89e2a659fa3 --- /dev/null +++ b/sdk/python/test_data/pipelines/if_else_with_oneof_artifacts.yaml @@ -0,0 +1,380 @@ +# PIPELINE DEFINITION +# Name: outer-pipeline +components: + comp-condition-2: + dag: + outputs: + artifacts: + pipelinechannel--param-to-artifact-a: + artifactSelectors: + - outputArtifactKey: a + producerSubtask: param-to-artifact + tasks: + param-to-artifact: + cachingOptions: + enableCache: true + componentRef: + name: comp-param-to-artifact + inputs: + parameters: + val: + componentInputParameter: pipelinechannel--flip-coin-Output + taskInfo: + name: param-to-artifact + inputDefinitions: + parameters: + pipelinechannel--flip-coin-Output: + parameterType: STRING + outputDefinitions: + artifacts: + pipelinechannel--param-to-artifact-a: + artifactType: + schemaTitle: system.Artifact + schemaVersion: 0.0.1 + comp-condition-3: + dag: + outputs: + artifacts: + pipelinechannel--param-to-artifact-2-a: + artifactSelectors: + - outputArtifactKey: a + producerSubtask: param-to-artifact-2 + tasks: + param-to-artifact-2: + cachingOptions: + enableCache: true + componentRef: + name: comp-param-to-artifact-2 + inputs: + parameters: + val: + componentInputParameter: pipelinechannel--flip-coin-Output + taskInfo: + name: param-to-artifact-2 + inputDefinitions: + parameters: + pipelinechannel--flip-coin-Output: + parameterType: STRING + outputDefinitions: + artifacts: + pipelinechannel--param-to-artifact-2-a: + artifactType: + schemaTitle: system.Artifact + schemaVersion: 0.0.1 + comp-condition-branches-1: + dag: + outputs: + artifacts: + pipelinechannel--condition-branches-1-oneof-1: + artifactSelectors: + - outputArtifactKey: pipelinechannel--param-to-artifact-a + producerSubtask: condition-2 + - outputArtifactKey: pipelinechannel--param-to-artifact-2-a + producerSubtask: condition-3 + tasks: + condition-2: + componentRef: + name: comp-condition-2 + inputs: + parameters: + pipelinechannel--flip-coin-Output: + componentInputParameter: pipelinechannel--flip-coin-Output + taskInfo: + name: condition-2 + triggerPolicy: + condition: inputs.parameter_values['pipelinechannel--flip-coin-Output'] + == 'heads' + condition-3: + componentRef: + name: comp-condition-3 + inputs: + parameters: + pipelinechannel--flip-coin-Output: + 
componentInputParameter: pipelinechannel--flip-coin-Output + taskInfo: + name: condition-3 + triggerPolicy: + condition: '!(inputs.parameter_values[''pipelinechannel--flip-coin-Output''] + == ''heads'')' + inputDefinitions: + parameters: + pipelinechannel--flip-coin-Output: + parameterType: STRING + outputDefinitions: + artifacts: + pipelinechannel--condition-branches-1-oneof-1: + artifactType: + schemaTitle: system.Artifact + schemaVersion: 0.0.1 + comp-flip-coin: + executorLabel: exec-flip-coin + outputDefinitions: + parameters: + Output: + parameterType: STRING + comp-flip-coin-pipeline: + dag: + outputs: + artifacts: + Output: + artifactSelectors: + - outputArtifactKey: pipelinechannel--condition-branches-1-oneof-1 + producerSubtask: condition-branches-1 + tasks: + condition-branches-1: + componentRef: + name: comp-condition-branches-1 + dependentTasks: + - flip-coin + inputs: + parameters: + pipelinechannel--flip-coin-Output: + taskOutputParameter: + outputParameterKey: Output + producerTask: flip-coin + taskInfo: + name: condition-branches-1 + flip-coin: + cachingOptions: + enableCache: true + componentRef: + name: comp-flip-coin + taskInfo: + name: flip-coin + print-artifact: + cachingOptions: + enableCache: true + componentRef: + name: comp-print-artifact + dependentTasks: + - condition-branches-1 + inputs: + artifacts: + a: + taskOutputArtifact: + outputArtifactKey: pipelinechannel--condition-branches-1-oneof-1 + producerTask: condition-branches-1 + taskInfo: + name: print-artifact + outputDefinitions: + artifacts: + Output: + artifactType: + schemaTitle: system.Artifact + schemaVersion: 0.0.1 + comp-param-to-artifact: + executorLabel: exec-param-to-artifact + inputDefinitions: + parameters: + val: + parameterType: STRING + outputDefinitions: + artifacts: + a: + artifactType: + schemaTitle: system.Artifact + schemaVersion: 0.0.1 + comp-param-to-artifact-2: + executorLabel: exec-param-to-artifact-2 + inputDefinitions: + parameters: + val: + parameterType: STRING + outputDefinitions: + artifacts: + a: + artifactType: + schemaTitle: system.Artifact + schemaVersion: 0.0.1 + comp-print-artifact: + executorLabel: exec-print-artifact + inputDefinitions: + artifacts: + a: + artifactType: + schemaTitle: system.Artifact + schemaVersion: 0.0.1 + comp-print-artifact-2: + executorLabel: exec-print-artifact-2 + inputDefinitions: + artifacts: + a: + artifactType: + schemaTitle: system.Artifact + schemaVersion: 0.0.1 +deploymentSpec: + executors: + exec-flip-coin: + container: + args: + - --executor_input + - '{{$}}' + - --function_to_execute + - flip_coin + command: + - sh + - -c + - "\nif ! 
[ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ + \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ + \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ + \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ + $0\" \"$@\"\n" + - sh + - -ec + - 'program_path=$(mktemp -d) + + + printf "%s" "$0" > "$program_path/ephemeral_component.py" + + _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" + + ' + - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ + \ *\n\ndef flip_coin() -> str:\n import random\n return 'heads' if\ + \ random.randint(0, 1) == 0 else 'tails'\n\n" + image: python:3.7 + exec-param-to-artifact: + container: + args: + - --executor_input + - '{{$}}' + - --function_to_execute + - param_to_artifact + command: + - sh + - -c + - "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ + \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ + \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ + \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ + $0\" \"$@\"\n" + - sh + - -ec + - 'program_path=$(mktemp -d) + + + printf "%s" "$0" > "$program_path/ephemeral_component.py" + + _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" + + ' + - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ + \ *\n\ndef param_to_artifact(val: str, a: Output[Artifact]):\n with open(a.path,\ + \ 'w') as f:\n f.write(val)\n\n" + image: python:3.7 + exec-param-to-artifact-2: + container: + args: + - --executor_input + - '{{$}}' + - --function_to_execute + - param_to_artifact + command: + - sh + - -c + - "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ + \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ + \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ + \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ + $0\" \"$@\"\n" + - sh + - -ec + - 'program_path=$(mktemp -d) + + + printf "%s" "$0" > "$program_path/ephemeral_component.py" + + _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" + + ' + - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ + \ *\n\ndef param_to_artifact(val: str, a: Output[Artifact]):\n with open(a.path,\ + \ 'w') as f:\n f.write(val)\n\n" + image: python:3.7 + exec-print-artifact: + container: + args: + - --executor_input + - '{{$}}' + - --function_to_execute + - print_artifact + command: + - sh + - -c + - "\nif ! 
[ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ + \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ + \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ + \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ + $0\" \"$@\"\n" + - sh + - -ec + - 'program_path=$(mktemp -d) + + + printf "%s" "$0" > "$program_path/ephemeral_component.py" + + _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" + + ' + - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ + \ *\n\ndef print_artifact(a: Input[Artifact]):\n with open(a.path) as\ + \ f:\n print(f.read())\n\n" + image: python:3.7 + exec-print-artifact-2: + container: + args: + - --executor_input + - '{{$}}' + - --function_to_execute + - print_artifact + command: + - sh + - -c + - "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ + \ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ + \ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.3.0'\ + \ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ + $0\" \"$@\"\n" + - sh + - -ec + - 'program_path=$(mktemp -d) + + + printf "%s" "$0" > "$program_path/ephemeral_component.py" + + _KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" + + ' + - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ + \ *\n\ndef print_artifact(a: Input[Artifact]):\n with open(a.path) as\ + \ f:\n print(f.read())\n\n" + image: python:3.7 +pipelineInfo: + name: outer-pipeline +root: + dag: + tasks: + flip-coin-pipeline: + cachingOptions: + enableCache: true + componentRef: + name: comp-flip-coin-pipeline + taskInfo: + name: flip-coin-pipeline + print-artifact: + cachingOptions: + enableCache: true + componentRef: + name: comp-print-artifact-2 + dependentTasks: + - flip-coin-pipeline + inputs: + artifacts: + a: + taskOutputArtifact: + outputArtifactKey: Output + producerTask: flip-coin-pipeline + taskInfo: + name: print-artifact +schemaVersion: 2.1.0 +sdkVersion: kfp-2.3.0 diff --git a/sdk/python/test_data/pipelines/if_else.py b/sdk/python/test_data/pipelines/if_else_with_oneof_parameters.py similarity index 79% rename from sdk/python/test_data/pipelines/if_else.py rename to sdk/python/test_data/pipelines/if_else_with_oneof_parameters.py index 1da8a074ac1..05f7f93403f 100644 --- a/sdk/python/test_data/pipelines/if_else.py +++ b/sdk/python/test_data/pipelines/if_else_with_oneof_parameters.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from kfp import compiler from kfp import dsl @@ -28,15 +27,19 @@ def print_and_return(text: str) -> str: @dsl.pipeline -def flip_coin_pipeline(): +def flip_coin_pipeline() -> str: flip_coin_task = flip_coin() with dsl.If(flip_coin_task.output == 'heads'): - print_and_return(text='Got heads!') + print_task_1 = print_and_return(text='Got heads!') with dsl.Else(): - print_and_return(text='Got tails!') + print_task_2 = print_and_return(text='Got tails!') + x = dsl.OneOf(print_task_1.output, print_task_2.output) + print_and_return(text=x) + return x if __name__ == '__main__': + from kfp import compiler compiler.Compiler().compile( pipeline_func=flip_coin_pipeline, package_path=__file__.replace('.py', '.yaml')) diff --git a/sdk/python/test_data/pipelines/if_elif_else.yaml b/sdk/python/test_data/pipelines/if_else_with_oneof_parameters.yaml similarity index 72% rename from sdk/python/test_data/pipelines/if_elif_else.yaml rename to sdk/python/test_data/pipelines/if_else_with_oneof_parameters.yaml index 3887ce09a97..873288dd7e4 100644 --- a/sdk/python/test_data/pipelines/if_elif_else.yaml +++ b/sdk/python/test_data/pipelines/if_else_with_oneof_parameters.yaml @@ -1,8 +1,16 @@ # PIPELINE DEFINITION -# Name: roll-die-pipeline +# Name: flip-coin-pipeline +# Outputs: +# Output: str components: comp-condition-2: dag: + outputs: + parameters: + pipelinechannel--print-and-return-Output: + valueFromParameter: + outputParameterKey: Output + producerSubtask: print-and-return tasks: print-and-return: cachingOptions: @@ -18,10 +26,20 @@ components: name: print-and-return inputDefinitions: parameters: - pipelinechannel--flip-three-sided-die-Output: + pipelinechannel--flip-coin-Output: + parameterType: STRING + outputDefinitions: + parameters: + pipelinechannel--print-and-return-Output: parameterType: STRING comp-condition-3: dag: + outputs: + parameters: + pipelinechannel--print-and-return-2-Output: + valueFromParameter: + outputParameterKey: Output + producerSubtask: print-and-return-2 tasks: print-and-return-2: cachingOptions: @@ -37,74 +55,58 @@ components: name: print-and-return-2 inputDefinitions: parameters: - pipelinechannel--flip-three-sided-die-Output: + pipelinechannel--flip-coin-Output: parameterType: STRING - comp-condition-4: - dag: - tasks: - print-and-return-3: - cachingOptions: - enableCache: true - componentRef: - name: comp-print-and-return-3 - inputs: - parameters: - text: - runtimeValue: - constant: Draw! 
- taskInfo: - name: print-and-return-3 - inputDefinitions: + outputDefinitions: parameters: - pipelinechannel--flip-three-sided-die-Output: + pipelinechannel--print-and-return-2-Output: parameterType: STRING comp-condition-branches-1: dag: + outputs: + parameters: + pipelinechannel--condition-branches-1-oneof-1: + valueFromOneof: + parameterSelectors: + - outputParameterKey: pipelinechannel--print-and-return-Output + producerSubtask: condition-2 + - outputParameterKey: pipelinechannel--print-and-return-2-Output + producerSubtask: condition-3 tasks: condition-2: componentRef: name: comp-condition-2 inputs: parameters: - pipelinechannel--flip-three-sided-die-Output: - componentInputParameter: pipelinechannel--flip-three-sided-die-Output + pipelinechannel--flip-coin-Output: + componentInputParameter: pipelinechannel--flip-coin-Output taskInfo: name: condition-2 triggerPolicy: - condition: inputs.parameter_values['pipelinechannel--flip-three-sided-die-Output'] + condition: inputs.parameter_values['pipelinechannel--flip-coin-Output'] == 'heads' condition-3: componentRef: name: comp-condition-3 inputs: parameters: - pipelinechannel--flip-three-sided-die-Output: - componentInputParameter: pipelinechannel--flip-three-sided-die-Output + pipelinechannel--flip-coin-Output: + componentInputParameter: pipelinechannel--flip-coin-Output taskInfo: name: condition-3 triggerPolicy: - condition: '!(inputs.parameter_values[''pipelinechannel--flip-three-sided-die-Output''] - == ''heads'') && inputs.parameter_values[''pipelinechannel--flip-three-sided-die-Output''] - == ''tails''' - condition-4: - componentRef: - name: comp-condition-4 - inputs: - parameters: - pipelinechannel--flip-three-sided-die-Output: - componentInputParameter: pipelinechannel--flip-three-sided-die-Output - taskInfo: - name: condition-4 - triggerPolicy: - condition: '!(inputs.parameter_values[''pipelinechannel--flip-three-sided-die-Output''] - == ''heads'') && !(inputs.parameter_values[''pipelinechannel--flip-three-sided-die-Output''] - == ''tails'')' + condition: '!(inputs.parameter_values[''pipelinechannel--flip-coin-Output''] + == ''heads'')' inputDefinitions: parameters: - pipelinechannel--flip-three-sided-die-Output: + pipelinechannel--flip-coin-Output: + parameterType: STRING + outputDefinitions: + parameters: + pipelinechannel--condition-branches-1-oneof-1: parameterType: STRING - comp-flip-three-sided-die: - executorLabel: exec-flip-three-sided-die + comp-flip-coin: + executorLabel: exec-flip-coin outputDefinitions: parameters: Output: @@ -141,13 +143,13 @@ components: parameterType: STRING deploymentSpec: executors: - exec-flip-three-sided-die: + exec-flip-coin: container: args: - --executor_input - '{{$}}' - --function_to_execute - - flip_three_sided_die + - flip_coin command: - sh - -c @@ -167,10 +169,8 @@ deploymentSpec: ' - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ - \ *\n\ndef flip_three_sided_die() -> str:\n import random\n val =\ - \ random.randint(0, 2)\n\n if val == 0:\n return 'heads'\n \ - \ elif val == 1:\n return 'tails'\n else:\n return 'draw'\n\ - \n" + \ *\n\ndef flip_coin() -> str:\n import random\n return 'heads' if\ + \ random.randint(0, 1) == 0 else 'tails'\n\n" image: python:3.7 exec-print-and-return: container: @@ -260,29 +260,54 @@ deploymentSpec: \ text\n\n" image: python:3.7 pipelineInfo: - name: roll-die-pipeline + name: flip-coin-pipeline root: dag: + outputs: + parameters: + Output: + valueFromParameter: + outputParameterKey: 
pipelinechannel--condition-branches-1-oneof-1
+            producerSubtask: condition-branches-1
     tasks:
       condition-branches-1:
         componentRef:
           name: comp-condition-branches-1
         dependentTasks:
-        - flip-three-sided-die
+        - flip-coin
         inputs:
           parameters:
-            pipelinechannel--flip-three-sided-die-Output:
+            pipelinechannel--flip-coin-Output:
               taskOutputParameter:
                 outputParameterKey: Output
-                producerTask: flip-three-sided-die
+                producerTask: flip-coin
         taskInfo:
           name: condition-branches-1
-      flip-three-sided-die:
+      flip-coin:
         cachingOptions:
           enableCache: true
         componentRef:
-          name: comp-flip-three-sided-die
+          name: comp-flip-coin
+        taskInfo:
+          name: flip-coin
+      print-and-return-3:
+        cachingOptions:
+          enableCache: true
+        componentRef:
+          name: comp-print-and-return-3
+        dependentTasks:
+        - condition-branches-1
+        inputs:
+          parameters:
+            text:
+              taskOutputParameter:
+                outputParameterKey: pipelinechannel--condition-branches-1-oneof-1
+                producerTask: condition-branches-1
         taskInfo:
-          name: flip-three-sided-die
+          name: print-and-return-3
+  outputDefinitions:
+    parameters:
+      Output:
+        parameterType: STRING
 schemaVersion: 2.1.0
 sdkVersion: kfp-2.3.0
diff --git a/sdk/python/test_data/test_data_config.yaml b/sdk/python/test_data/test_data_config.yaml
index 42e12c7c790..d64d7a1aea3 100644
--- a/sdk/python/test_data/test_data_config.yaml
+++ b/sdk/python/test_data/test_data_config.yaml
@@ -168,15 +168,18 @@ pipelines:
   - module: pipeline_with_metadata_fields
     name: dataset_concatenator
    execute: false
-  - module: if_else
-    name: flip_coin_pipeline
-    execute: false
-  - module: if_elif_else
-    name: roll_die_pipeline
+  - module: if_else_with_oneof_artifacts
+    name: outer_pipeline
     execute: false
   - module: if_elif_else_complex
     name: lucky_number_pipeline
     execute: false
+  - module: if_else_with_oneof_parameters
+    name: flip_coin_pipeline
+    execute: false
+  - module: if_elif_else_with_oneof_parameters
+    name: outer_pipeline
+    execute: false
 components:
   test_data_dir: sdk/python/test_data/components
   read: true

From f9148936fffe07af6c5084b281ae982389c7194c Mon Sep 17 00:00:00 2001
From: connor-mccarthy
Date: Sat, 14 Oct 2023 07:54:36 -0700
Subject: [PATCH 2/3] address review feedback

---
 sdk/python/kfp/compiler/compiler_test.py | 492 ++++++++++------------
 1 file changed, 207 insertions(+), 285 deletions(-)

diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py
index 1eb53de15dc..bb3b77e294b 100644
--- a/sdk/python/kfp/compiler/compiler_test.py
+++ b/sdk/python/kfp/compiler/compiler_test.py
@@ -4553,13 +4553,10 @@ class TestDslOneOf(unittest.TestCase):
 
     # Data type validation (e.g., dsl.OneOf(artifact, param) fails) and similar is covered in pipeline_channel_test.py.
 
-    # To help narrow the tests further (we already test lots of aspects in the following cases), we choose focus on dsl.OneOf behavior, not the conditional logic if If/Elif/Else. This is more verbose, but more maintainable and the behavior under test is clearer.
+    # To help narrow the tests further (we already test lots of aspects in the following cases), we choose to focus on the dsl.OneOf behavior, not the conditional logic of If/Elif/Else. This is more verbose, but more maintainable, and the behavior under test is clearer.
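+    # As a minimal orienting sketch (illustrative only, not itself a test case), the usage pattern exercised throughout is roughly:
+    #
+    #     with dsl.If(flip_coin_task.output == 'heads'):
+    #         t1 = print_and_return(text='Got heads!')
+    #     with dsl.Else():
+    #         t2 = print_and_return(text='Got tails!')
+    #     x = dsl.OneOf(t1.output, t2.output)
+    #
+    # where x may then be consumed by a downstream task, returned as a pipeline output, or both.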
def test_if_else_returned(self): - # if/else - # returned - # parameters - # different output keys + """Uses If and Else branches, parameters passed to dsl.OneOf, dsl.OneOf returned from a pipeline, and different output keys on dsl.OneOf channels.""" @dsl.pipeline def roll_die_pipeline() -> str: @@ -4596,39 +4593,31 @@ def roll_die_pipeline() -> str: parameter_selectors = roll_die_pipeline.pipeline_spec.components[ 'comp-condition-branches-1'].dag.outputs.parameters[ 'pipelinechannel--condition-branches-1-oneof-1'].value_from_oneof.parameter_selectors + self.assertEqual( - parameter_selectors[0].output_parameter_key, - 'pipelinechannel--print-and-return-Output', - ) - self.assertEqual( - parameter_selectors[0].producer_subtask, - 'condition-2', - ) - self.assertEqual( - parameter_selectors[1].output_parameter_key, - 'pipelinechannel--print-and-return-with-output-key-output_key', - ) + parameter_selectors[0], + pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec( + output_parameter_key='pipelinechannel--print-and-return-Output', + producer_subtask='condition-2', + )) self.assertEqual( - parameter_selectors[1].producer_subtask, - 'condition-3', - ) + parameter_selectors[1], + pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec( + output_parameter_key='pipelinechannel--print-and-return-with-output-key-output_key', + producer_subtask='condition-3', + )) # surfaced as output self.assertEqual( roll_die_pipeline.pipeline_spec.root.dag.outputs - .parameters['Output'].value_from_parameter.producer_subtask, - 'condition-branches-1', - ) - self.assertEqual( - roll_die_pipeline.pipeline_spec.root.dag.outputs - .parameters['Output'].value_from_parameter.output_parameter_key, - 'pipelinechannel--condition-branches-1-oneof-1', + .parameters['Output'].value_from_parameter, + pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec( + producer_subtask='condition-branches-1', + output_parameter_key='pipelinechannel--condition-branches-1-oneof-1', + ), ) def test_if_elif_else_returned(self): - # if/elif/else - # returned - # parameters - # different output keys + """Uses If, Elif, and Else branches, parameters passed to dsl.OneOf, dsl.OneOf returned from a pipeline, and different output keys on dsl.OneOf channels.""" @dsl.pipeline def roll_die_pipeline() -> str: @@ -4675,46 +4664,35 @@ def roll_die_pipeline() -> str: 'comp-condition-branches-1'].dag.outputs.parameters[ 'pipelinechannel--condition-branches-1-oneof-1'].value_from_oneof.parameter_selectors self.assertEqual( - parameter_selectors[0].output_parameter_key, - 'pipelinechannel--print-and-return-Output', - ) - self.assertEqual( - parameter_selectors[0].producer_subtask, - 'condition-2', - ) - self.assertEqual( - parameter_selectors[1].output_parameter_key, - 'pipelinechannel--print-and-return-2-Output', - ) - self.assertEqual( - parameter_selectors[1].producer_subtask, - 'condition-3', - ) - self.assertEqual( - parameter_selectors[2].output_parameter_key, - 'pipelinechannel--print-and-return-with-output-key-output_key', - ) - self.assertEqual( - parameter_selectors[2].producer_subtask, - 'condition-4', - ) + parameter_selectors[0], + pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec( + output_parameter_key='pipelinechannel--print-and-return-Output', + producer_subtask='condition-2', + )) + self.assertEqual( + parameter_selectors[1], + pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec( + output_parameter_key='pipelinechannel--print-and-return-2-Output', + producer_subtask='condition-3', + )) + self.assertEqual( + parameter_selectors[2], 
+            pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                output_parameter_key='pipelinechannel--print-and-return-with-output-key-output_key',
+                producer_subtask='condition-4',
+            ))
 
         # surfaced as output
         self.assertEqual(
             roll_die_pipeline.pipeline_spec.root.dag.outputs
-            .parameters['Output'].value_from_parameter.producer_subtask,
-            'condition-branches-1',
-        )
-        self.assertEqual(
-            roll_die_pipeline.pipeline_spec.root.dag.outputs
-            .parameters['Output'].value_from_parameter.output_parameter_key,
-            'pipelinechannel--condition-branches-1-oneof-1',
+            .parameters['Output'].value_from_parameter,
+            pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                producer_subtask='condition-branches-1',
+                output_parameter_key='pipelinechannel--condition-branches-1-oneof-1',
+            ),
         )
 
     def test_if_elif_else_consumed(self):
-        # tests if/elif/else
-        # returned
-        # parameters
-        # different output keys
+        """Uses If, Elif, and Else branches, parameters passed to dsl.OneOf, dsl.OneOf passed to a consumer task, and different output keys on dsl.OneOf channels."""
 
         @dsl.pipeline
         def roll_die_pipeline():
@@ -4762,47 +4740,36 @@ def roll_die_pipeline():
                 'comp-condition-branches-1'].dag.outputs.parameters[
                     'pipelinechannel--condition-branches-1-oneof-1'].value_from_oneof.parameter_selectors
         self.assertEqual(
-            parameter_selectors[0].output_parameter_key,
-            'pipelinechannel--print-and-return-Output',
-        )
-        self.assertEqual(
-            parameter_selectors[0].producer_subtask,
-            'condition-2',
-        )
-        self.assertEqual(
-            parameter_selectors[1].output_parameter_key,
-            'pipelinechannel--print-and-return-2-Output',
-        )
-        self.assertEqual(
-            parameter_selectors[1].producer_subtask,
-            'condition-3',
-        )
-        self.assertEqual(
-            parameter_selectors[2].output_parameter_key,
-            'pipelinechannel--print-and-return-with-output-key-output_key',
-        )
-        self.assertEqual(
-            parameter_selectors[2].producer_subtask,
-            'condition-4',
-        )
+            parameter_selectors[0],
+            pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                output_parameter_key='pipelinechannel--print-and-return-Output',
+                producer_subtask='condition-2',
+            ))
+        self.assertEqual(
+            parameter_selectors[1],
+            pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                output_parameter_key='pipelinechannel--print-and-return-2-Output',
+                producer_subtask='condition-3',
+            ))
+        self.assertEqual(
+            parameter_selectors[2],
+            pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                output_parameter_key='pipelinechannel--print-and-return-with-output-key-output_key',
+                producer_subtask='condition-4',
+            ))
 
         # consumed from condition-branches
         self.assertEqual(
             roll_die_pipeline.pipeline_spec.root.dag.tasks['print-and-return-3']
-            .inputs.parameters['text'].task_output_parameter.producer_task,
-            'condition-branches-1',
-        )
-        self.assertEqual(
-            roll_die_pipeline.pipeline_spec.root.dag.tasks['print-and-return-3']
-            .inputs.parameters['text'].task_output_parameter
-            .output_parameter_key,
-            'pipelinechannel--condition-branches-1-oneof-1',
+            .inputs.parameters['text'].task_output_parameter,
+            pipeline_spec_pb2.TaskInputsSpec.InputParameterSpec
+            .TaskOutputParameterSpec(
+                producer_task='condition-branches-1',
+                output_parameter_key='pipelinechannel--condition-branches-1-oneof-1',
+            ),
         )
 
     def test_if_else_consumed_and_returned(self):
-        # tests if/else
-        # consumed and returned
-        # parameters
-        # same output key
+        """Uses If and Else branches, parameters passed to dsl.OneOf, and dsl.OneOf passed to a consumer task and returned from the pipeline."""
 
         @dsl.pipeline
         def flip_coin_pipeline() -> str:
@@ -4841,52 +4808,41 @@ def flip_coin_pipeline() -> str:
                 'comp-condition-branches-1'].dag.outputs.parameters[
                     'pipelinechannel--condition-branches-1-oneof-1'].value_from_oneof.parameter_selectors
         self.assertEqual(
-            parameter_selectors[0].output_parameter_key,
-            'pipelinechannel--print-and-return-Output',
-        )
-        self.assertEqual(
-            parameter_selectors[0].producer_subtask,
-            'condition-2',
-        )
-        self.assertEqual(
-            parameter_selectors[1].output_parameter_key,
-            'pipelinechannel--print-and-return-2-Output',
-        )
-        self.assertEqual(
-            parameter_selectors[1].producer_subtask,
-            'condition-3',
-        )
+            parameter_selectors[0],
+            pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                output_parameter_key='pipelinechannel--print-and-return-Output',
+                producer_subtask='condition-2',
+            ))
+        self.assertEqual(
+            parameter_selectors[1],
+            pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                output_parameter_key='pipelinechannel--print-and-return-2-Output',
+                producer_subtask='condition-3',
+            ))
 
         # consumed from condition-branches
         self.assertEqual(
             flip_coin_pipeline.pipeline_spec.root.dag
             .tasks['print-and-return-3'].inputs.parameters['text']
-            .task_output_parameter.producer_task,
-            'condition-branches-1',
-        )
-        self.assertEqual(
-            flip_coin_pipeline.pipeline_spec.root.dag
-            .tasks['print-and-return-3'].inputs.parameters['text']
-            .task_output_parameter.output_parameter_key,
-            'pipelinechannel--condition-branches-1-oneof-1',
+            .task_output_parameter,
+            pipeline_spec_pb2.TaskInputsSpec.InputParameterSpec
+            .TaskOutputParameterSpec(
+                producer_task='condition-branches-1',
+                output_parameter_key='pipelinechannel--condition-branches-1-oneof-1',
+            ),
         )
 
         # surfaced as output
         self.assertEqual(
             flip_coin_pipeline.pipeline_spec.root.dag.outputs
-            .parameters['Output'].value_from_parameter.producer_subtask,
-            'condition-branches-1',
-        )
-        self.assertEqual(
-            flip_coin_pipeline.pipeline_spec.root.dag.outputs
-            .parameters['Output'].value_from_parameter.output_parameter_key,
-            'pipelinechannel--condition-branches-1-oneof-1',
+            .parameters['Output'].value_from_parameter,
+            pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                producer_subtask='condition-branches-1',
+                output_parameter_key='pipelinechannel--condition-branches-1-oneof-1',
+            ),
         )
 
     def test_if_else_consumed_and_returned_artifacts(self):
-        # tests if/else
-        # consumed and returned
-        # artifacts
-        # same output key
+        """Uses If and Else branches, artifacts passed to dsl.OneOf, and dsl.OneOf passed to a consumer task and returned from the pipeline."""
 
         @dsl.pipeline
         def flip_coin_pipeline() -> Artifact:
@@ -4927,48 +4883,41 @@ def flip_coin_pipeline() -> Artifact:
                 'comp-condition-branches-1'].dag.outputs.artifacts[
                     'pipelinechannel--condition-branches-1-oneof-1'].artifact_selectors
         self.assertEqual(
-            artifact_selectors[0].output_artifact_key,
-            'pipelinechannel--print-and-return-as-artifact-a',
-        )
+            artifact_selectors[0],
+            pipeline_spec_pb2.DagOutputsSpec.ArtifactSelectorSpec(
+                output_artifact_key='pipelinechannel--print-and-return-as-artifact-a',
+                producer_subtask='condition-2',
+            ))
         self.assertEqual(
-            artifact_selectors[0].producer_subtask,
-            'condition-2',
-        )
-        self.assertEqual(
-            artifact_selectors[1].output_artifact_key,
-            'pipelinechannel--print-and-return-as-artifact-2-a',
-        )
-        self.assertEqual(
-            artifact_selectors[1].producer_subtask,
-            'condition-3',
-        )
+            artifact_selectors[1],
+            pipeline_spec_pb2.DagOutputsSpec.ArtifactSelectorSpec(
+                output_artifact_key='pipelinechannel--print-and-return-as-artifact-2-a',
+ 
producer_subtask='condition-3', + )) + # consumed from condition-branches self.assertEqual( flip_coin_pipeline.pipeline_spec.root.dag.tasks['print-artifact'] - .inputs.artifacts['a'].task_output_artifact.producer_task, - 'condition-branches-1', - ) - self.assertEqual( - flip_coin_pipeline.pipeline_spec.root.dag.tasks['print-artifact'] - .inputs.artifacts['a'].task_output_artifact.output_artifact_key, - 'pipelinechannel--condition-branches-1-oneof-1', + .inputs.artifacts['a'].task_output_artifact, + pipeline_spec_pb2.TaskInputsSpec.InputArtifactSpec + .TaskOutputArtifactSpec( + producer_task='condition-branches-1', + output_artifact_key='pipelinechannel--condition-branches-1-oneof-1', + ), ) # surfaced as output self.assertEqual( flip_coin_pipeline.pipeline_spec.root.dag.outputs - .artifacts['Output'].artifact_selectors[0].producer_subtask, - 'condition-branches-1', - ) - self.assertEqual( - flip_coin_pipeline.pipeline_spec.root.dag.outputs - .artifacts['Output'].artifact_selectors[0].output_artifact_key, - 'pipelinechannel--condition-branches-1-oneof-1', + .artifacts['Output'].artifact_selectors[0], + pipeline_spec_pb2.DagOutputsSpec.ArtifactSelectorSpec( + producer_subtask='condition-branches-1', + output_artifact_key='pipelinechannel--condition-branches-1-oneof-1', + ), ) def test_nested_under_condition_consumed(self): - # nested under loop and condition - # artifact + """Uses If, Else, and OneOf nested under a parent If.""" @dsl.pipeline def flip_coin_pipeline(execute_pipeline: bool): @@ -5012,38 +4961,29 @@ def flip_coin_pipeline(execute_pipeline: bool): 'comp-condition-branches-2'].dag.outputs.artifacts[ 'pipelinechannel--condition-branches-2-oneof-1'].artifact_selectors self.assertEqual( - artifact_selectors[0].output_artifact_key, - 'pipelinechannel--print-and-return-as-artifact-a', - ) - self.assertEqual( - artifact_selectors[0].producer_subtask, - 'condition-3', - ) - self.assertEqual( - artifact_selectors[1].output_artifact_key, - 'pipelinechannel--print-and-return-as-artifact-2-a', - ) - self.assertEqual( - artifact_selectors[1].producer_subtask, - 'condition-4', - ) + artifact_selectors[0], + pipeline_spec_pb2.DagOutputsSpec.ArtifactSelectorSpec( + output_artifact_key='pipelinechannel--print-and-return-as-artifact-a', + producer_subtask='condition-3', + )) + self.assertEqual( + artifact_selectors[1], + pipeline_spec_pb2.DagOutputsSpec.ArtifactSelectorSpec( + output_artifact_key='pipelinechannel--print-and-return-as-artifact-2-a', + producer_subtask='condition-4', + )) # consumed from condition-branches self.assertEqual( flip_coin_pipeline.pipeline_spec.components['comp-condition-1'].dag - .tasks['print-artifact'].inputs.artifacts['a'].task_output_artifact - .producer_task, - 'condition-branches-2', - ) - self.assertEqual( - flip_coin_pipeline.pipeline_spec.components['comp-condition-1'].dag - .tasks['print-artifact'].inputs.artifacts['a'].task_output_artifact - .output_artifact_key, - 'pipelinechannel--condition-branches-2-oneof-1', + .tasks['print-artifact'].inputs.artifacts['a'].task_output_artifact, + pipeline_spec_pb2.TaskInputsSpec.InputArtifactSpec + .TaskOutputArtifactSpec( + producer_task='condition-branches-2', + output_artifact_key='pipelinechannel--condition-branches-2-oneof-1', + ), ) def test_nested_under_condition_returned_raises(self): - # nested under loop and condition - # artifact with self.assertRaisesRegex( compiler_utils.InvalidTopologyException, f'Pipeline outputs may only be returned from the top level of the pipeline function scope\. 
Got pipeline output dsl\.OneOf from within the control flow group dsl\.If\.'
@@ -5063,9 +5003,7 @@ def flip_coin_pipeline(execute_pipeline: bool):
                     print_task_2.outputs['a'])
 
     def test_deeply_nested_consumed(self):
-        # nested under loop and condition and exit handler
-        # consumed
-        # artifact
+        """Uses If, Elif, Else, and OneOf deeply nested within multiple sub-DAGs."""
 
         @dsl.pipeline
         def flip_coin_pipeline(execute_pipeline: bool):
@@ -5089,21 +5027,15 @@ def flip_coin_pipeline(execute_pipeline: bool):
         # consumed from condition-branches
         self.assertEqual(
             flip_coin_pipeline.pipeline_spec.components['comp-condition-4'].dag
-            .tasks['print-artifact'].inputs.artifacts['a'].task_output_artifact
-            .producer_task,
-            'condition-branches-5',
-        )
-        self.assertEqual(
-            flip_coin_pipeline.pipeline_spec.components['comp-condition-4'].dag
-            .tasks['print-artifact'].inputs.artifacts['a'].task_output_artifact
-            .output_artifact_key,
-            'pipelinechannel--condition-branches-5-oneof-1',
+            .tasks['print-artifact'].inputs.artifacts['a'].task_output_artifact,
+            pipeline_spec_pb2.TaskInputsSpec.InputArtifactSpec
+            .TaskOutputArtifactSpec(
+                producer_task='condition-branches-5',
+                output_artifact_key='pipelinechannel--condition-branches-5-oneof-1',
+            ),
         )
 
     def test_deeply_nested_returned_raises(self):
-        # nested under loop and condition
-        # returned
-        # artifact
 
         with self.assertRaisesRegex(
                 compiler_utils.InvalidTopologyException,
@@ -5169,7 +5101,61 @@ def flip_coin_pipeline(execute_pipeline: bool):
                 return dsl.OneOf(print_task_1.outputs['a'],
                                  print_task_2.outputs['a'])
 
+    def test_oneof_in_condition(self):
+        """Tests that dsl.OneOf's channel can be consumed in a downstream group nested one level deep."""
+
+        @dsl.pipeline
+        def roll_die_pipeline(repeat_on: str = 'Got heads!'):
+            flip_coin_task = roll_three_sided_die()
+            with dsl.If(flip_coin_task.output == 'heads'):
+                t1 = print_and_return(text='Got heads!')
+            with dsl.Elif(flip_coin_task.output == 'tails'):
+                t2 = print_and_return(text='Got tails!')
+            with dsl.Else():
+                t3 = print_and_return_with_output_key(text='Draw!')
+            x = dsl.OneOf(t1.output, t2.output, t3.outputs['output_key'])
+
+            with dsl.If(x == repeat_on):
+                print_and_return(text=x)
+
+        # condition-branches surfaces
+        self.assertEqual(
+            roll_die_pipeline.pipeline_spec
+            .components['comp-condition-branches-1'].output_definitions
+            .parameters['pipelinechannel--condition-branches-1-oneof-1']
+            .parameter_type,
+            type_utils.STRING,
+        )
+        parameter_selectors = roll_die_pipeline.pipeline_spec.components[
+            'comp-condition-branches-1'].dag.outputs.parameters[
+                'pipelinechannel--condition-branches-1-oneof-1'].value_from_oneof.parameter_selectors
+        self.assertEqual(
+            parameter_selectors[0],
+            pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                output_parameter_key='pipelinechannel--print-and-return-Output',
+                producer_subtask='condition-2',
+            ))
+        self.assertEqual(
+            parameter_selectors[1],
+            pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                output_parameter_key='pipelinechannel--print-and-return-2-Output',
+                producer_subtask='condition-3',
+            ))
+        self.assertEqual(
+            parameter_selectors[2],
+            pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                output_parameter_key='pipelinechannel--print-and-return-with-output-key-output_key',
+                producer_subtask='condition-4',
+            ))
+        # condition points to correct upstream output
+        self.assertEqual(
+            roll_die_pipeline.pipeline_spec.root.dag.tasks['condition-5']
+            .trigger_policy.condition,
+            "inputs.parameter_values['pipelinechannel--condition-branches-1-pipelinechannel--condition-branches-1-oneof-1'] == inputs.parameter_values['pipelinechannel--repeat_on']"
+        )
+
     def test_consumed_in_nested_groups(self):
+        """Tests that dsl.OneOf's channel can be consumed in a downstream group nested multiple levels deep."""
 
         @dsl.pipeline
         def roll_die_pipeline(
@@ -5201,29 +5187,23 @@ def roll_die_pipeline(
                 'comp-condition-branches-1'].dag.outputs.parameters[
                     'pipelinechannel--condition-branches-1-oneof-1'].value_from_oneof.parameter_selectors
         self.assertEqual(
-            parameter_selectors[0].output_parameter_key,
-            'pipelinechannel--print-and-return-Output',
-        )
-        self.assertEqual(
-            parameter_selectors[0].producer_subtask,
-            'condition-2',
-        )
-        self.assertEqual(
-            parameter_selectors[1].output_parameter_key,
-            'pipelinechannel--print-and-return-2-Output',
-        )
-        self.assertEqual(
-            parameter_selectors[1].producer_subtask,
-            'condition-3',
-        )
-        self.assertEqual(
-            parameter_selectors[2].output_parameter_key,
-            'pipelinechannel--print-and-return-with-output-key-output_key',
-        )
-        self.assertEqual(
-            parameter_selectors[2].producer_subtask,
-            'condition-4',
-        )
+            parameter_selectors[0],
+            pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                output_parameter_key='pipelinechannel--print-and-return-Output',
+                producer_subtask='condition-2',
+            ))
+        self.assertEqual(
+            parameter_selectors[1],
+            pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                output_parameter_key='pipelinechannel--print-and-return-2-Output',
+                producer_subtask='condition-3',
+            ))
+        self.assertEqual(
+            parameter_selectors[2],
+            pipeline_spec_pb2.DagOutputsSpec.ParameterSelectorSpec(
+                output_parameter_key='pipelinechannel--print-and-return-with-output-key-output_key',
+                producer_subtask='condition-4',
+            ))
 
         # condition points to correct upstream output
         self.assertEqual(
             roll_die_pipeline.pipeline_spec.components['comp-condition-6']
@@ -5256,64 +5236,6 @@ def roll_die_pipeline():
             text=f"Final result: {dsl.OneOf(t1.output, t2.output, t3.outputs['output_key'])}"
         )
 
-    def test_oneof_in_condition(self):
-
-        @dsl.pipeline
-        def roll_die_pipeline(repeat_on: str = 'Got heads!'):
-            flip_coin_task = roll_three_sided_die()
-            with dsl.If(flip_coin_task.output == 'heads'):
-                t1 = print_and_return(text='Got heads!')
-            with dsl.Elif(flip_coin_task.output == 'tails'):
-                t2 = print_and_return(text='Got tails!')
-            with dsl.Else():
-                t3 = print_and_return_with_output_key(text='Draw!')
-            x = dsl.OneOf(t1.output, t2.output, t3.outputs['output_key'])
-
-            with dsl.If(x == repeat_on):
-                print_and_return(text=x)
-
-        # condition-branches surfaces
-        self.assertEqual(
-            roll_die_pipeline.pipeline_spec
-            .components['comp-condition-branches-1'].output_definitions
-            .parameters['pipelinechannel--condition-branches-1-oneof-1']
-            .parameter_type,
-            type_utils.STRING,
-        )
-        parameter_selectors = roll_die_pipeline.pipeline_spec.components[
-            'comp-condition-branches-1'].dag.outputs.parameters[
-                'pipelinechannel--condition-branches-1-oneof-1'].value_from_oneof.parameter_selectors
-        self.assertEqual(
-            parameter_selectors[0].output_parameter_key,
-            'pipelinechannel--print-and-return-Output',
-        )
-        self.assertEqual(
-            parameter_selectors[0].producer_subtask,
-            'condition-2',
-        )
-        self.assertEqual(
-            parameter_selectors[1].output_parameter_key,
-            'pipelinechannel--print-and-return-2-Output',
-        )
-        self.assertEqual(
-            parameter_selectors[1].producer_subtask,
-            'condition-3',
-        )
-        self.assertEqual(
-            parameter_selectors[2].output_parameter_key,
- 
'pipelinechannel--print-and-return-with-output-key-output_key', - ) - self.assertEqual( - parameter_selectors[2].producer_subtask, - 'condition-4', - ) - # condition points to correct upstream output - self.assertEqual( - roll_die_pipeline.pipeline_spec.root.dag.tasks['condition-5'] - .trigger_policy.condition, - "inputs.parameter_values['pipelinechannel--condition-branches-1-pipelinechannel--condition-branches-1-oneof-1'] == inputs.parameter_values['pipelinechannel--repeat_on']" - ) - def test_type_checking_parameters(self): with self.assertRaisesRegex( type_utils.InconsistentTypeException, From 8021e900ff0bad7cb0857926840475ef8a936c8e Mon Sep 17 00:00:00 2001 From: connor-mccarthy Date: Tue, 17 Oct 2023 18:19:44 -0700 Subject: [PATCH 3/3] address review feedback --- sdk/python/kfp/compiler/compiler_test.py | 2 +- sdk/python/kfp/dsl/pipeline_channel.py | 32 +++++++++------------ sdk/python/kfp/dsl/pipeline_channel_test.py | 2 +- sdk/python/kfp/dsl/tasks_group.py | 2 +- 4 files changed, 16 insertions(+), 22 deletions(-) diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index bb3b77e294b..b5d7a5267d7 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -5221,7 +5221,7 @@ def roll_die_pipeline( def test_oneof_in_fstring(self): with self.assertRaisesRegex( NotImplementedError, - f'dsl\.OneOf is not yet supported in f-strings\.'): + f'dsl\.OneOf does not support string interpolation\.'): @dsl.pipeline def roll_die_pipeline(): diff --git a/sdk/python/kfp/dsl/pipeline_channel.py b/sdk/python/kfp/dsl/pipeline_channel.py index 7f93ca38dc0..6adb52525cd 100644 --- a/sdk/python/kfp/dsl/pipeline_channel.py +++ b/sdk/python/kfp/dsl/pipeline_channel.py @@ -313,7 +313,7 @@ def _make_oneof_name(self) -> str: # In the downstream compiler logic, we get to treat this output like a # normal task output. return compiler_utils.additional_input_name_for_pipeline_channel( - f'{self.condition_branches_group.name}-oneof-{self.condition_branches_group._get_oneof_id()}' + f'{self.condition_branches_group.name}-oneof-{self.condition_branches_group.get_oneof_id()}' ) def _validate_channels( @@ -350,25 +350,19 @@ def _validate_no_mix_of_parameters_and_artifacts( self, channels: List[Union[PipelineParameterChannel, PipelineArtifactChannel]] ) -> None: - readable_name_map = { - PipelineParameterChannel: 'parameter', - PipelineArtifactChannel: 'artifact', - OneOfParameter: 'parameter', - OneOfArtifact: 'artifact', - } - # if channels[0] is any subclass of a PipelineParameterChannel - # check the rest of the channels against that parent type - # ensures check permits OneOfParameter and PipelineParameterChannel - # to be passed to OneOf together - if isinstance(channels[0], PipelineParameterChannel): - expected_type = PipelineParameterChannel + + first_channel = channels[0] + if isinstance(first_channel, PipelineParameterChannel): + first_channel_type = PipelineParameterChannel else: - expected_type = PipelineArtifactChannel + first_channel_type = PipelineArtifactChannel - for i, channel in enumerate(channels[1:], start=1): - if not isinstance(channel, expected_type): + for channel in channels: + # if not all channels match the first channel's type, then there + # is a mix of parameter and artifact channels + if not isinstance(channel, first_channel_type): raise TypeError( - f'Task outputs passed to dsl.{OneOf.__name__} must be the same type. 
Got two channels with different types: {readable_name_map[expected_type]} at index 0 and {readable_name_map[type(channel)]} at index {i}.' + f'Task outputs passed to dsl.{OneOf.__name__} must be the same type. Found a mix of parameters and artifacts passed to dsl.{OneOf.__name__}.' ) def _validate_has_else_group( @@ -391,7 +385,7 @@ def __str__(self): # the combination of OneOf and an f-string is not common, so prefer # deferring implementation raise NotImplementedError( - f'dsl.{OneOf.__name__} is not yet supported in f-strings.') + f'dsl.{OneOf.__name__} does not support string interpolation.') @property def pattern(self) -> str: @@ -411,7 +405,7 @@ def pattern(self) -> str: # isinstance(, PipelineParameterChannel/PipelineArtifactChannel) # checks continue to behave in desirable ways class OneOfParameter(PipelineParameterChannel, OneOfMixin): - """OneOf that results in an parameter channel for all downstream tasks.""" + """OneOf that results in a parameter channel for all downstream tasks.""" def __init__(self, channels: List[PipelineParameterChannel]) -> None: self.channels = channels diff --git a/sdk/python/kfp/dsl/pipeline_channel_test.py b/sdk/python/kfp/dsl/pipeline_channel_test.py index db4120bb768..b0b72be0830 100644 --- a/sdk/python/kfp/dsl/pipeline_channel_test.py +++ b/sdk/python/kfp/dsl/pipeline_channel_test.py @@ -357,7 +357,7 @@ def artifact_comp(out: Output[Artifact]): with self.assertRaisesRegex( TypeError, - r'Task outputs passed to dsl\.OneOf must be the same type\. Got two channels with different types: artifact at index 0 and parameter at index 1\.' + r'Task outputs passed to dsl\.OneOf must be the same type\. Found a mix of parameters and artifacts passed to dsl\.OneOf\.' ): @dsl.pipeline diff --git a/sdk/python/kfp/dsl/tasks_group.py b/sdk/python/kfp/dsl/tasks_group.py index 3cfa737c392..3f0f758bbd3 100644 --- a/sdk/python/kfp/dsl/tasks_group.py +++ b/sdk/python/kfp/dsl/tasks_group.py @@ -153,7 +153,7 @@ def __init__(self) -> None: is_root=False, ) - def _get_oneof_id(self) -> int: + def get_oneof_id(self) -> int: """Incrementor for uniquely identifying a OneOf for the parent ConditionBranches group.