fix(sdk): Make Artifact type be compatible with any sub-artifact types bidirectionally (#6859)
chensun authored Nov 3, 2021
1 parent fe0865e commit dea0823
Showing 2 changed files with 63 additions and 10 deletions.
15 changes: 9 additions & 6 deletions sdk/python/kfp/dsl/types.py
@@ -140,12 +140,15 @@ def verify_type_compatibility(given_type: TypeSpecType,
if given_type is None or expected_type is None:
return True

# Generic artifacts resulted from missing type or explicit "Artifact" type can
# be passed to inputs expecting any artifact types.
# However, generic artifacts resulted from arbitrary unknown types do not have
# such "compatible" feature.
if not type_utils.is_parameter_type(str(expected_type)) and (
given_type is None or str(given_type).lower() == "artifact"):
# The generic "Artifact" type, whether from a missing type annotation or an
# explicit "Artifact" annotation, is compatible with any artifact type, in
# either direction.
# However, generic artifacts resulting from arbitrary unknown types do not
# have this compatibility.
if not type_utils.is_parameter_type(
str(expected_type)) and str(given_type).lower() == "artifact":
return True
if not type_utils.is_parameter_type(
str(given_type)) and str(expected_type).lower() == "artifact":
return True

types_are_compatible = check_types(given_type, expected_type)
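To make the new rule concrete, here is a minimal, self-contained sketch of the bidirectional check added above. The parameter-type set and helper are illustrative stand-ins for kfp.dsl.type_utils.is_parameter_type, and the final equality fallback stands in for check_types; the real logic lives in kfp.dsl.types.verify_type_compatibility.

# Minimal sketch of the bidirectional Artifact compatibility rule shown in the
# diff above. PARAMETER_TYPES and _is_parameter_type are stand-ins for
# kfp.dsl.type_utils; the real function falls back to check_types() rather
# than plain string equality.
PARAMETER_TYPES = {'string', 'integer', 'double', 'bool', 'list', 'dict'}

def _is_parameter_type(type_name: str) -> bool:
    return type_name.lower() in PARAMETER_TYPES

def artifact_types_compatible(given_type, expected_type) -> bool:
    # A generic "Artifact" may feed an input expecting any concrete artifact
    # type (Dataset, Model, ...), and, with this fix, a concrete artifact type
    # may also feed an input expecting the generic "Artifact".
    if not _is_parameter_type(str(expected_type)) and str(
            given_type).lower() == 'artifact':
        return True
    if not _is_parameter_type(str(given_type)) and str(
            expected_type).lower() == 'artifact':
        return True
    return str(given_type) == str(expected_type)

assert artifact_types_compatible('Artifact', 'Dataset')   # generic -> concrete
assert artifact_types_compatible('Model', 'Artifact')     # concrete -> generic (the new direction)
assert not artifact_types_compatible('Model', 'Dataset')  # unrelated concrete types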
58 changes: 54 additions & 4 deletions sdk/python/kfp/v2/compiler/compiler_test.py
@@ -23,7 +23,6 @@
from kfp.v2 import dsl
from kfp.dsl import types


VALID_PRODUCER_COMPONENT_SAMPLE = components.load_component_from_text("""
name: producer
inputs:
@@ -40,6 +39,7 @@
- {outputPath: output_value}
""")


class CompilerTest(unittest.TestCase):

def test_compile_simple_pipeline(self):
@@ -177,10 +177,10 @@ def test_compile_pipeline_with_missing_task_should_raise_error(self):
def my_pipeline(text: str):
pass

with self.assertRaisesRegex(
ValueError,'Task is missing from pipeline.'):
with self.assertRaisesRegex(ValueError,
'Task is missing from pipeline.'):
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path='output.json')
pipeline_func=my_pipeline, package_path='output.json')

def test_compile_pipeline_with_misused_inputuri_should_raise_error(self):

@@ -353,6 +353,56 @@ def my_pipeline():
finally:
shutil.rmtree(tmpdir)

def test_passing_concrete_artifact_to_input_expecting_generic_artifact(
self):

producer_op1 = components.load_component_from_text("""
name: producer component
outputs:
- {name: output, type: Dataset}
implementation:
container:
image: dummy
args:
- {outputPath: output}
""")

@dsl.component
def producer_op2(output: dsl.Output[dsl.Model]):
pass

consumer_op1 = components.load_component_from_text("""
name: consumer component
inputs:
- {name: input, type: Artifact}
implementation:
container:
image: dummy
args:
- {inputPath: input}
""")

@dsl.component
def consumer_op2(input: dsl.Input[dsl.Artifact]):
pass

@dsl.pipeline(name='test-pipeline')
def my_pipeline():
consumer_op1(producer_op1().output)
consumer_op1(producer_op2().output)
consumer_op2(producer_op1().output)
consumer_op2(producer_op2().output)

try:
tmpdir = tempfile.mkdtemp()
target_json_file = os.path.join(tmpdir, 'result.json')
compiler.Compiler().compile(
pipeline_func=my_pipeline, package_path=target_json_file)

self.assertTrue(os.path.exists(target_json_file))
finally:
shutil.rmtree(tmpdir)

def test_passing_arbitrary_artifact_to_input_expecting_concrete_artifact(
self):

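At the authoring level, the sketch below shows how the two directions look in user code. It mirrors the style of the new test above; the component names, pipeline name, and output path are hypothetical.

# A minimal authoring-level sketch, assuming the kfp.v2 SDK at this commit.
from kfp.v2 import compiler, dsl

@dsl.component
def make_dataset(output: dsl.Output[dsl.Dataset]):
    pass

@dsl.component
def make_artifact(output: dsl.Output[dsl.Artifact]):
    pass

@dsl.component
def consume_artifact(input: dsl.Input[dsl.Artifact]):
    pass

@dsl.component
def consume_dataset(input: dsl.Input[dsl.Dataset]):
    pass

@dsl.pipeline(name='artifact-compatibility-example')
def example_pipeline():
    # Concrete -> generic: the direction this commit enables.
    consume_artifact(input=make_dataset().output)
    # Generic -> concrete: the direction that was already accepted.
    consume_dataset(input=make_artifact().output)

if __name__ == '__main__':
    compiler.Compiler().compile(
        pipeline_func=example_pipeline,
        package_path='artifact_compatibility_example.json')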
