From dea0823fe823fbffa224faf059a6de9bf13f4129 Mon Sep 17 00:00:00 2001 From: Chen Sun Date: Wed, 3 Nov 2021 16:34:21 -0700 Subject: [PATCH] fix(sdk): Make `Artifact` type be compatible with any sub-artifact types bidirectionally (#6859) --- sdk/python/kfp/dsl/types.py | 15 +++--- sdk/python/kfp/v2/compiler/compiler_test.py | 58 +++++++++++++++++++-- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/sdk/python/kfp/dsl/types.py b/sdk/python/kfp/dsl/types.py index e476584b6c3..152da36c6eb 100644 --- a/sdk/python/kfp/dsl/types.py +++ b/sdk/python/kfp/dsl/types.py @@ -140,12 +140,15 @@ def verify_type_compatibility(given_type: TypeSpecType, if given_type is None or expected_type is None: return True - # Generic artifacts resulted from missing type or explicit "Artifact" type can - # be passed to inputs expecting any artifact types. - # However, generic artifacts resulted from arbitrary unknown types do not have - # such "compatible" feature. - if not type_utils.is_parameter_type(str(expected_type)) and ( - given_type is None or str(given_type).lower() == "artifact"): + # Generic artifacts resulted from missing type or explicit "Artifact" type + # is compatible with any artifact types. + # However, generic artifacts resulted from arbitrary unknown types do not + # have such "compatible" feature. + if not type_utils.is_parameter_type( + str(expected_type)) and str(given_type).lower() == "artifact": + return True + if not type_utils.is_parameter_type( + str(given_type)) and str(expected_type).lower() == "artifact": return True types_are_compatible = check_types(given_type, expected_type) diff --git a/sdk/python/kfp/v2/compiler/compiler_test.py b/sdk/python/kfp/v2/compiler/compiler_test.py index 0f8e4cf4c1e..e0033bf5dd0 100644 --- a/sdk/python/kfp/v2/compiler/compiler_test.py +++ b/sdk/python/kfp/v2/compiler/compiler_test.py @@ -23,7 +23,6 @@ from kfp.v2 import dsl from kfp.dsl import types - VALID_PRODUCER_COMPONENT_SAMPLE = components.load_component_from_text(""" name: producer inputs: @@ -40,6 +39,7 @@ - {outputPath: output_value} """) + class CompilerTest(unittest.TestCase): def test_compile_simple_pipeline(self): @@ -177,10 +177,10 @@ def test_compile_pipeline_with_missing_task_should_raise_error(self): def my_pipeline(text: str): pass - with self.assertRaisesRegex( - ValueError,'Task is missing from pipeline.'): + with self.assertRaisesRegex(ValueError, + 'Task is missing from pipeline.'): compiler.Compiler().compile( - pipeline_func=my_pipeline, package_path='output.json') + pipeline_func=my_pipeline, package_path='output.json') def test_compile_pipeline_with_misused_inputuri_should_raise_error(self): @@ -353,6 +353,56 @@ def my_pipeline(): finally: shutil.rmtree(tmpdir) + def test_passing_concrete_artifact_to_input_expecting_generic_artifact( + self): + + producer_op1 = components.load_component_from_text(""" + name: producer compoent + outputs: + - {name: output, type: Dataset} + implementation: + container: + image: dummy + args: + - {outputPath: output} + """) + + @dsl.component + def producer_op2(output: dsl.Output[dsl.Model]): + pass + + consumer_op1 = components.load_component_from_text(""" + name: consumer compoent + inputs: + - {name: input, type: Artifact} + implementation: + container: + image: dummy + args: + - {inputPath: input} + """) + + @dsl.component + def consumer_op2(input: dsl.Input[dsl.Artifact]): + pass + + @dsl.pipeline(name='test-pipeline') + def my_pipeline(): + consumer_op1(producer_op1().output) + consumer_op1(producer_op2().output) + consumer_op2(producer_op1().output) + consumer_op2(producer_op2().output) + + try: + tmpdir = tempfile.mkdtemp() + target_json_file = os.path.join(tmpdir, 'result.json') + compiler.Compiler().compile( + pipeline_func=my_pipeline, package_path=target_json_file) + + self.assertTrue(os.path.exists(target_json_file)) + finally: + shutil.rmtree(tmpdir) + def test_passing_arbitrary_artifact_to_input_expecting_concrete_artifact( self):