From 41752608bc5da2d67fbc23ba8e1fa7a78e77a50c Mon Sep 17 00:00:00 2001 From: Connor McCarthy Date: Thu, 16 Feb 2023 09:26:53 -0800 Subject: [PATCH] support list of artifact input placeholders (#8484) --- sdk/python/kfp/compiler/compiler_test.py | 92 ++++++++++++++++++- .../kfp/compiler/pipeline_spec_builder.py | 12 ++- .../kfp/components/component_factory.py | 58 ++++++++---- .../kfp/components/component_factory_test.py | 28 +++--- ...ntainer_component_artifact_channel_test.py | 6 +- sdk/python/kfp/components/placeholders.py | 25 ++++- .../kfp/components/placeholders_test.py | 79 +++++++++++++++- sdk/python/kfp/components/structures.py | 10 +- 8 files changed, 266 insertions(+), 44 deletions(-) diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index 92e9c10d20a..792472726c8 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -673,8 +673,8 @@ def my_pipeline(): pipeline_spec['deploymentSpec']['executors']) def test_pipeline_with_invalid_output(self): - with self.assertRaisesRegex(ValueError, - 'Pipeline output not defined: msg1'): + with self.assertRaisesRegex( + ValueError, r'Pipeline or component output not defined: msg1'): @dsl.component def print_op(msg: str) -> str: @@ -3092,5 +3092,93 @@ def comp(message: str) -> str: comp.pipeline_spec.root.output_definitions.parameters) +class TestListOfArtifactsInterfaceCompileAndLoad(unittest.TestCase): + + def test_python_component(self): + + @dsl.component + def python_component(input_list: Input[List[Artifact]]): + pass + + self.assertEqual( + python_component.pipeline_spec.root.input_definitions + .artifacts['input_list'].is_artifact_list, True) + self.assertEqual( + python_component.pipeline_spec.components['comp-python-component'] + .input_definitions.artifacts['input_list'].is_artifact_list, True) + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile( + pipeline_func=python_component, package_path=output_yaml) + loaded_component = components.load_component_from_file(output_yaml) + + self.assertEqual( + loaded_component.pipeline_spec.root.input_definitions + .artifacts['input_list'].is_artifact_list, True) + self.assertEqual( + loaded_component.pipeline_spec.components['comp-python-component'] + .input_definitions.artifacts['input_list'].is_artifact_list, True) + + def test_container_component(self): + + @dsl.container_component + def container_component(input_list: Input[List[Artifact]]): + return dsl.ContainerSpec( + image='alpine', command=['echo'], args=['hello world']) + + self.assertEqual( + container_component.pipeline_spec.root.input_definitions + .artifacts['input_list'].is_artifact_list, True) + self.assertEqual( + container_component.pipeline_spec + .components['comp-container-component'].input_definitions + .artifacts['input_list'].is_artifact_list, True) + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile( + pipeline_func=container_component, package_path=output_yaml) + loaded_component = components.load_component_from_file(output_yaml) + + self.assertEqual( + loaded_component.pipeline_spec.root.input_definitions + .artifacts['input_list'].is_artifact_list, True) + self.assertEqual( + loaded_component.pipeline_spec + .components['comp-container-component'].input_definitions + .artifacts['input_list'].is_artifact_list, True) + + def test_pipeline(self): + + @dsl.component + def python_component(input_list: Input[List[Artifact]]): + pass + + @dsl.pipeline + def pipeline_component(input_list: Input[List[Artifact]]): + python_component(input_list=input_list) + + self.assertEqual( + pipeline_component.pipeline_spec.root.input_definitions + .artifacts['input_list'].is_artifact_list, True) + self.assertEqual( + pipeline_component.pipeline_spec.components['comp-python-component'] + .input_definitions.artifacts['input_list'].is_artifact_list, True) + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile( + pipeline_func=pipeline_component, package_path=output_yaml) + loaded_component = components.load_component_from_file(output_yaml) + + self.assertEqual( + loaded_component.pipeline_spec.root.input_definitions + .artifacts['input_list'].is_artifact_list, True) + self.assertEqual( + loaded_component.pipeline_spec.components['comp-python-component'] + .input_definitions.artifacts['input_list'].is_artifact_list, True) + + if __name__ == '__main__': unittest.main() diff --git a/sdk/python/kfp/compiler/pipeline_spec_builder.py b/sdk/python/kfp/compiler/pipeline_spec_builder.py index 7468fd160f7..c1eee7b1378 100644 --- a/sdk/python/kfp/compiler/pipeline_spec_builder.py +++ b/sdk/python/kfp/compiler/pipeline_spec_builder.py @@ -380,6 +380,8 @@ def _build_component_spec_from_component_spec_structure( input_name].artifact_type.CopyFrom( type_utils.bundled_artifact_to_artifact_proto( input_spec.type)) + component_spec.input_definitions.artifacts[ + input_name].is_artifact_list = input_spec.is_artifact_list if input_spec.optional: component_spec.input_definitions.artifacts[ input_name].is_optional = True @@ -395,6 +397,8 @@ def _build_component_spec_from_component_spec_structure( output_name].artifact_type.CopyFrom( type_utils.bundled_artifact_to_artifact_proto( output_spec.type)) + component_spec.output_definitions.artifacts[ + output_name].is_artifact_list = output_spec.is_artifact_list return component_spec @@ -413,7 +417,9 @@ def _connect_dag_outputs( """ if isinstance(output_channel, pipeline_channel.PipelineArtifactChannel): if output_name not in component_spec.output_definitions.artifacts: - raise ValueError(f'DAG output not defined: {output_name}.') + raise ValueError( + f'Pipeline or component output not defined: {output_name}. You may be missing a type annotation.' + ) component_spec.dag.outputs.artifacts[ output_name].artifact_selectors.append( pipeline_spec_pb2.DagOutputsSpec.ArtifactSelectorSpec( @@ -422,7 +428,9 @@ def _connect_dag_outputs( )) elif isinstance(output_channel, pipeline_channel.PipelineParameterChannel): if output_name not in component_spec.output_definitions.parameters: - raise ValueError(f'Pipeline output not defined: {output_name}.') + raise ValueError( + f'Pipeline or component output not defined: {output_name}. You may be missing a type annotation.' + ) component_spec.dag.outputs.parameters[ output_name].value_from_parameter.producer_subtask = output_channel.task_name component_spec.dag.outputs.parameters[ diff --git a/sdk/python/kfp/components/component_factory.py b/sdk/python/kfp/components/component_factory.py index 070de0014f0..46d0c263271 100644 --- a/sdk/python/kfp/components/component_factory.py +++ b/sdk/python/kfp/components/component_factory.py @@ -17,17 +17,17 @@ import pathlib import re import textwrap -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Type, Union import warnings import docstring_parser from kfp.components import container_component +from kfp.components import container_component_artifact_channel from kfp.components import graph_component from kfp.components import placeholders from kfp.components import python_component from kfp.components import structures -from kfp.components.container_component_artifact_channel import \ - ContainerComponentArtifactChannel +from kfp.components.types import artifact_types from kfp.components.types import custom_artifact_types from kfp.components.types import type_annotations from kfp.components.types import type_utils @@ -477,6 +477,38 @@ def create_component_from_func( component_spec=component_spec, python_func=func) +def make_input_for_parameterized_container_component_function( + name: str, annotation: Union[Type[List[artifact_types.Artifact]], + Type[artifact_types.Artifact]] +) -> Union[placeholders.Placeholder, container_component_artifact_channel + .ContainerComponentArtifactChannel]: + if type_annotations.is_input_artifact(annotation): + + if type_annotations.is_list_of_artifacts(annotation.__origin__): + return placeholders.InputListOfArtifactsPlaceholder(name) + else: + return container_component_artifact_channel.ContainerComponentArtifactChannel( + io_type='input', var_name=name) + + elif type_annotations.is_output_artifact(annotation): + + if type_annotations.is_list_of_artifacts(annotation.__origin__): + raise ValueError( + 'Outputting a list of artifacts from a Custom Container Component is not currently supported.' + ) + else: + return container_component_artifact_channel.ContainerComponentArtifactChannel( + io_type='output', var_name=name) + + elif isinstance( + annotation, + (type_annotations.OutputAnnotation, type_annotations.OutputPath)): + return placeholders.OutputParameterPlaceholder(name) + + else: + return placeholders.InputValuePlaceholder(name) + + def create_container_component_from_func( func: Callable) -> container_component.ContainerComponent: """Implementation for the @container_component decorator. @@ -486,27 +518,15 @@ def create_container_component_from_func( """ component_spec = extract_component_interface(func, containerized=True) - arg_list = [] signature = inspect.signature(func) parameters = list(signature.parameters.values()) + arg_list = [] for parameter in parameters: parameter_type = type_annotations.maybe_strip_optional_from_annotation( parameter.annotation) - io_name = parameter.name - if type_annotations.is_input_artifact(parameter_type): - arg_list.append( - ContainerComponentArtifactChannel( - io_type='input', var_name=io_name)) - elif type_annotations.is_output_artifact(parameter_type): - arg_list.append( - ContainerComponentArtifactChannel( - io_type='output', var_name=io_name)) - elif isinstance( - parameter_type, - (type_annotations.OutputAnnotation, type_annotations.OutputPath)): - arg_list.append(placeholders.OutputParameterPlaceholder(io_name)) - else: # parameter is an input value - arg_list.append(placeholders.InputValuePlaceholder(io_name)) + arg_list.append( + make_input_for_parameterized_container_component_function( + parameter.name, parameter_type)) container_spec = func(*arg_list) container_spec_implementation = structures.ContainerSpecImplementation.from_container_spec( diff --git a/sdk/python/kfp/components/component_factory_test.py b/sdk/python/kfp/components/component_factory_test.py index 5251c199cab..92a67373eac 100644 --- a/sdk/python/kfp/components/component_factory_test.py +++ b/sdk/python/kfp/components/component_factory_test.py @@ -132,20 +132,6 @@ def comp(i: Input[List[Model]]): is_artifact_list=True) }) - def test_pipeline_with_named_tuple_fn(self): - from typing import NamedTuple - - def comp( - i: Input[List[Model]] - ) -> NamedTuple('outputs', [('output_list', List[Artifact])]): - ... - - with self.assertRaisesRegex( - ValueError, - r'Cannot use output lists of artifacts in NamedTuple return annotations. Got output list of artifacts annotation for NamedTuple field `output_list`\.' - ): - component_factory.extract_component_interface(comp) - class TestOutputListsOfArtifactsTemporarilyBlocked(unittest.TestCase): @@ -176,6 +162,20 @@ def test_pipeline(self): def comp() -> List[Artifact]: ... + def test_pipeline_with_named_tuple_fn(self): + from typing import NamedTuple + + def comp( + i: Input[List[Model]] + ) -> NamedTuple('outputs', [('output_list', List[Artifact])]): + ... + + with self.assertRaisesRegex( + ValueError, + r'Cannot use output lists of artifacts in NamedTuple return annotations. Got output list of artifacts annotation for NamedTuple field `output_list`\.' + ): + component_factory.extract_component_interface(comp) + if __name__ == '__main__': unittest.main() diff --git a/sdk/python/kfp/components/container_component_artifact_channel_test.py b/sdk/python/kfp/components/container_component_artifact_channel_test.py index d9568b4f64a..35eaeafdd6a 100644 --- a/sdk/python/kfp/components/container_component_artifact_channel_test.py +++ b/sdk/python/kfp/components/container_component_artifact_channel_test.py @@ -14,16 +14,16 @@ import unittest -from kfp.components import component_factory +from kfp.components import container_component_artifact_channel from kfp.components import placeholders class TestContainerComponentArtifactChannel(unittest.TestCase): def test_correct_placeholder_and_attribute_error(self): - in_channel = component_factory.ContainerComponentArtifactChannel( + in_channel = container_component_artifact_channel.ContainerComponentArtifactChannel( 'input', 'my_dataset') - out_channel = component_factory.ContainerComponentArtifactChannel( + out_channel = container_component_artifact_channel.ContainerComponentArtifactChannel( 'output', 'my_result') self.assertEqual( in_channel.uri._to_string(), diff --git a/sdk/python/kfp/components/placeholders.py b/sdk/python/kfp/components/placeholders.py index 9c6bf5db56f..bfb5d72833d 100644 --- a/sdk/python/kfp/components/placeholders.py +++ b/sdk/python/kfp/components/placeholders.py @@ -57,6 +57,28 @@ def _to_string(self) -> str: return f"{{{{$.inputs.parameters['{self.input_name}']}}}}" +class InputListOfArtifactsPlaceholder(Placeholder): + + def __init__(self, input_name: str) -> None: + self.input_name = input_name + + def _to_string(self) -> str: + return f"{{{{$.inputs.artifacts['{self.input_name}']}}}}" + + def __getattribute__(self, name: str) -> Any: + if name in {'name', 'uri', 'metadata', 'path'}: + raise AttributeError( + f'Cannot access an attribute on a list of artifacts in a Custom Container Component. Found reference to attribute {name!r} on {self.input_name!r}. Please pass the whole list of artifacts only.' + ) + else: + return object.__getattribute__(self, name) + + def __getitem__(self, k: int) -> None: + raise KeyError( + f'Cannot access individual artifacts in a list of artifacts. Found access to element {k} on {self.input_name!r}. Please pass the whole list of artifacts only.' + ) + + class InputPathPlaceholder(Placeholder): def __init__(self, input_name: str) -> None: @@ -270,7 +292,8 @@ def __str__(self) -> str: _CONTAINER_PLACEHOLDERS = (IfPresentPlaceholder, ConcatPlaceholder) PRIMITIVE_INPUT_PLACEHOLDERS = (InputValuePlaceholder, InputPathPlaceholder, - InputUriPlaceholder, InputMetadataPlaceholder) + InputUriPlaceholder, InputMetadataPlaceholder, + InputListOfArtifactsPlaceholder) PRIMITIVE_OUTPUT_PLACEHOLDERS = (OutputParameterPlaceholder, OutputPathPlaceholder, OutputUriPlaceholder, OutputMetadataPlaceholder) diff --git a/sdk/python/kfp/components/placeholders_test.py b/sdk/python/kfp/components/placeholders_test.py index 89b2b0db6a6..11638c45f4f 100644 --- a/sdk/python/kfp/components/placeholders_test.py +++ b/sdk/python/kfp/components/placeholders_test.py @@ -14,13 +14,15 @@ """Contains tests for kfp.components.placeholders.""" import os import tempfile -from typing import Any +from typing import Any, List from absl.testing import parameterized from kfp import compiler from kfp import dsl from kfp.components import placeholders from kfp.dsl import Artifact +from kfp.dsl import Dataset +from kfp.dsl import Input from kfp.dsl import Output @@ -379,6 +381,81 @@ def test_valid_then_but_invalid_else(self): ]) +class TestListOfArtifactsInContainerComponentPlaceholders( + parameterized.TestCase): + + def test_compile_component1(self): + + @dsl.container_component + def comp(input_list: Input[List[Artifact]]): + return dsl.ContainerSpec( + image='alpine', command=[input_list], args=[input_list]) + + self.assertEqual( + comp.pipeline_spec.deployment_spec['executors']['exec-comp'] + ['container']['command'][0], "{{$.inputs.artifacts['input_list']}}") + self.assertEqual( + comp.pipeline_spec.deployment_spec['executors']['exec-comp'] + ['container']['args'][0], "{{$.inputs.artifacts['input_list']}}") + + def test_compile_component2(self): + + @dsl.container_component + def comp(new_name: Input[List[Dataset]]): + return dsl.ContainerSpec( + image='alpine', command=[new_name], args=[new_name]) + + self.assertEqual( + comp.pipeline_spec.deployment_spec['executors']['exec-comp'] + ['container']['command'][0], "{{$.inputs.artifacts['new_name']}}") + self.assertEqual( + comp.pipeline_spec.deployment_spec['executors']['exec-comp'] + ['container']['args'][0], "{{$.inputs.artifacts['new_name']}}") + + def test_cannot_access_name(self): + with self.assertRaisesRegex(AttributeError, + 'Cannot access an attribute'): + + @dsl.container_component + def comp(new_name: Input[List[Dataset]]): + return dsl.ContainerSpec( + image='alpine', command=[new_name.name]) + + def test_cannot_access_uri(self): + with self.assertRaisesRegex(AttributeError, + 'Cannot access an attribute'): + + @dsl.container_component + def comp(new_name: Input[List[Dataset]]): + return dsl.ContainerSpec(image='alpine', command=[new_name.uri]) + + def test_cannot_access_metadata(self): + with self.assertRaisesRegex(AttributeError, + 'Cannot access an attribute'): + + @dsl.container_component + def comp(new_name: Input[List[Dataset]]): + return dsl.ContainerSpec( + image='alpine', command=[new_name.metadata]) + + def test_cannot_access_path(self): + with self.assertRaisesRegex(AttributeError, + 'Cannot access an attribute'): + + @dsl.container_component + def comp(new_name: Input[List[Dataset]]): + return dsl.ContainerSpec( + image='alpine', command=[new_name.path]) + + def test_cannot_access_individual_artifact(self): + with self.assertRaisesRegex(KeyError, + 'Cannot access individual artifacts'): + + @dsl.container_component + def comp(new_name: Input[List[Dataset]]): + return dsl.ContainerSpec(image='alpine', command=[new_name[0]]) + + class TestConvertCommandLineElementToStringOrStruct(parameterized.TestCase): @parameterized.parameters(['a', 'word', 1]) diff --git a/sdk/python/kfp/components/structures.py b/sdk/python/kfp/components/structures.py index 95a5be6b45e..f01b717f018 100644 --- a/sdk/python/kfp/components/structures.py +++ b/sdk/python/kfp/components/structures.py @@ -91,10 +91,13 @@ def from_ir_component_inputs_dict( # TODO: would be better to extract these fields from the proto # message, as False default would be preserved optional = ir_component_inputs_dict.get('isOptional', False) + is_artifact_list = ir_component_inputs_dict.get( + 'isArtifactList', False) return InputSpec( type=type_utils.create_bundled_artifact_type( type_, schema_version), - optional=optional) + optional=optional, + is_artifact_list=is_artifact_list) def __eq__(self, other: Any) -> bool: """Equality comparison for InputSpec. Robust to different type @@ -177,9 +180,12 @@ def from_ir_component_outputs_dict( type_ = ir_component_outputs_dict['artifactType']['schemaTitle'] schema_version = ir_component_outputs_dict['artifactType'][ 'schemaVersion'] + is_artifact_list = ir_component_outputs_dict.get( + 'isArtifactList', False) return OutputSpec( type=type_utils.create_bundled_artifact_type( - type_, schema_version)) + type_, schema_version), + is_artifact_list=is_artifact_list) def __eq__(self, other: Any) -> bool: """Equality comparison for OutputSpec. Robust to different type