Skip to content

Commit

Permalink
support list of artifact input placeholders (#8484)
Browse files (browse the repository at this point in the history)
  • Loading branch information
connor-mccarthy authored Feb 16, 2023
1 parent 03b7752 commit 4175260
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 44 deletions.
92 changes: 90 additions & 2 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,8 +673,8 @@ def my_pipeline():
pipeline_spec['deploymentSpec']['executors'])

def test_pipeline_with_invalid_output(self):
with self.assertRaisesRegex(ValueError,
'Pipeline output not defined: msg1'):
with self.assertRaisesRegex(
ValueError, r'Pipeline or component output not defined: msg1'):

@dsl.component
def print_op(msg: str) -> str:
Expand Down Expand Up @@ -3092,5 +3092,93 @@ def comp(message: str) -> str:
comp.pipeline_spec.root.output_definitions.parameters)


class TestListOfArtifactsInterfaceCompileAndLoad(unittest.TestCase):
    """Checks `is_artifact_list` on list-of-artifact inputs.

    Each test inspects the in-memory pipeline spec of a component, then
    compiles the component to YAML, loads it back, and inspects the
    loaded pipeline spec the same way.
    """

    def _assert_input_list_flagged(self, component, component_key):
        """Asserts `input_list` has `is_artifact_list` equal to True.

        Args:
            component: Object exposing a `pipeline_spec` attribute.
            component_key: Key into `pipeline_spec.components` whose input
                definitions are checked after the root input definitions.
        """
        spec = component.pipeline_spec

        root_input = spec.root.input_definitions.artifacts['input_list']
        self.assertEqual(root_input.is_artifact_list, True)

        inner_definitions = spec.components[component_key].input_definitions
        inner_input = inner_definitions.artifacts['input_list']
        self.assertEqual(inner_input.is_artifact_list, True)

    def _compile_and_reload(self, component):
        """Compiles `component` to a temporary YAML file and loads it back."""
        with tempfile.TemporaryDirectory() as workdir:
            package_path = os.path.join(workdir, 'pipeline.yaml')
            compiler.Compiler().compile(
                pipeline_func=component, package_path=package_path)
            # Load while the temporary directory still exists.
            return components.load_component_from_file(package_path)

    def test_python_component(self):

        @dsl.component
        def python_component(input_list: Input[List[Artifact]]):
            pass

        component_key = 'comp-python-component'
        self._assert_input_list_flagged(python_component, component_key)

        reloaded = self._compile_and_reload(python_component)
        self._assert_input_list_flagged(reloaded, component_key)

    def test_container_component(self):

        @dsl.container_component
        def container_component(input_list: Input[List[Artifact]]):
            return dsl.ContainerSpec(
                image='alpine', command=['echo'], args=['hello world'])

        component_key = 'comp-container-component'
        self._assert_input_list_flagged(container_component, component_key)

        reloaded = self._compile_and_reload(container_component)
        self._assert_input_list_flagged(reloaded, component_key)

    def test_pipeline(self):

        @dsl.component
        def python_component(input_list: Input[List[Artifact]]):
            pass

        @dsl.pipeline
        def pipeline_component(input_list: Input[List[Artifact]]):
            python_component(input_list=input_list)

        # The inner component keeps its own key inside the pipeline spec.
        component_key = 'comp-python-component'
        self._assert_input_list_flagged(pipeline_component, component_key)

        reloaded = self._compile_and_reload(pipeline_component)
        self._assert_input_list_flagged(reloaded, component_key)


if __name__ == '__main__':
unittest.main()
12 changes: 10 additions & 2 deletions sdk/python/kfp/compiler/pipeline_spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,8 @@ def _build_component_spec_from_component_spec_structure(
input_name].artifact_type.CopyFrom(
type_utils.bundled_artifact_to_artifact_proto(
input_spec.type))
component_spec.input_definitions.artifacts[
input_name].is_artifact_list = input_spec.is_artifact_list
if input_spec.optional:
component_spec.input_definitions.artifacts[
input_name].is_optional = True
Expand All @@ -395,6 +397,8 @@ def _build_component_spec_from_component_spec_structure(
output_name].artifact_type.CopyFrom(
type_utils.bundled_artifact_to_artifact_proto(
output_spec.type))
component_spec.output_definitions.artifacts[
output_name].is_artifact_list = output_spec.is_artifact_list

return component_spec

Expand All @@ -413,7 +417,9 @@ def _connect_dag_outputs(
"""
if isinstance(output_channel, pipeline_channel.PipelineArtifactChannel):
if output_name not in component_spec.output_definitions.artifacts:
raise ValueError(f'DAG output not defined: {output_name}.')
raise ValueError(
f'Pipeline or component output not defined: {output_name}. You may be missing a type annotation.'
)
component_spec.dag.outputs.artifacts[
output_name].artifact_selectors.append(
pipeline_spec_pb2.DagOutputsSpec.ArtifactSelectorSpec(
Expand All @@ -422,7 +428,9 @@ def _connect_dag_outputs(
))
elif isinstance(output_channel, pipeline_channel.PipelineParameterChannel):
if output_name not in component_spec.output_definitions.parameters:
raise ValueError(f'Pipeline output not defined: {output_name}.')
raise ValueError(
f'Pipeline or component output not defined: {output_name}. You may be missing a type annotation.'
)
component_spec.dag.outputs.parameters[
output_name].value_from_parameter.producer_subtask = output_channel.task_name
component_spec.dag.outputs.parameters[
Expand Down
58 changes: 39 additions & 19 deletions sdk/python/kfp/components/component_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
import pathlib
import re
import textwrap
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Optional, Tuple, Type, Union
import warnings

import docstring_parser
from kfp.components import container_component
from kfp.components import container_component_artifact_channel
from kfp.components import graph_component
from kfp.components import placeholders
from kfp.components import python_component
from kfp.components import structures
from kfp.components.container_component_artifact_channel import \
ContainerComponentArtifactChannel
from kfp.components.types import artifact_types
from kfp.components.types import custom_artifact_types
from kfp.components.types import type_annotations
from kfp.components.types import type_utils
Expand Down Expand Up @@ -477,6 +477,38 @@ def create_component_from_func(
component_spec=component_spec, python_func=func)


def make_input_for_parameterized_container_component_function(
    name: str, annotation: Union[Type[List[artifact_types.Artifact]],
                                 Type[artifact_types.Artifact]]
) -> Union[placeholders.Placeholder, container_component_artifact_channel
           .ContainerComponentArtifactChannel]:
    """Builds the argument object passed to a container component function.

    The kind of object returned is selected from the parameter annotation.

    Args:
        name: Name of the function parameter.
        annotation: Annotation of the function parameter.

    Returns:
        An `InputListOfArtifactsPlaceholder` for an input artifact
        annotation wrapping a list of artifacts, a
        `ContainerComponentArtifactChannel` for any other input or output
        artifact annotation, an `OutputParameterPlaceholder` for
        `OutputAnnotation` / `OutputPath` instances, and an
        `InputValuePlaceholder` for everything else.

    Raises:
        ValueError: If the annotation is an output artifact annotation
            wrapping a list of artifacts.
    """

    def _artifact_channel(io_type: str):
        # Shared constructor call for the single-artifact cases.
        return container_component_artifact_channel.ContainerComponentArtifactChannel(
            io_type=io_type, var_name=name)

    if type_annotations.is_input_artifact(annotation):
        wraps_list = type_annotations.is_list_of_artifacts(
            annotation.__origin__)
        if wraps_list:
            return placeholders.InputListOfArtifactsPlaceholder(name)
        return _artifact_channel('input')

    if type_annotations.is_output_artifact(annotation):
        wraps_list = type_annotations.is_list_of_artifacts(
            annotation.__origin__)
        if wraps_list:
            raise ValueError(
                'Outputting a list of artifacts from a Custom Container Component is not currently supported.'
            )
        return _artifact_channel('output')

    output_parameter_types = (type_annotations.OutputAnnotation,
                              type_annotations.OutputPath)
    if isinstance(annotation, output_parameter_types):
        return placeholders.OutputParameterPlaceholder(name)

    # Anything not matched above is handled as an input value.
    return placeholders.InputValuePlaceholder(name)


def create_container_component_from_func(
func: Callable) -> container_component.ContainerComponent:
"""Implementation for the @container_component decorator.
Expand All @@ -486,27 +518,15 @@ def create_container_component_from_func(
"""

component_spec = extract_component_interface(func, containerized=True)
arg_list = []
signature = inspect.signature(func)
parameters = list(signature.parameters.values())
arg_list = []
for parameter in parameters:
parameter_type = type_annotations.maybe_strip_optional_from_annotation(
parameter.annotation)
io_name = parameter.name
if type_annotations.is_input_artifact(parameter_type):
arg_list.append(
ContainerComponentArtifactChannel(
io_type='input', var_name=io_name))
elif type_annotations.is_output_artifact(parameter_type):
arg_list.append(
ContainerComponentArtifactChannel(
io_type='output', var_name=io_name))
elif isinstance(
parameter_type,
(type_annotations.OutputAnnotation, type_annotations.OutputPath)):
arg_list.append(placeholders.OutputParameterPlaceholder(io_name))
else: # parameter is an input value
arg_list.append(placeholders.InputValuePlaceholder(io_name))
arg_list.append(
make_input_for_parameterized_container_component_function(
parameter.name, parameter_type))

container_spec = func(*arg_list)
container_spec_implementation = structures.ContainerSpecImplementation.from_container_spec(
Expand Down
28 changes: 14 additions & 14 deletions sdk/python/kfp/components/component_factory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,20 +132,6 @@ def comp(i: Input[List[Model]]):
is_artifact_list=True)
})

def test_pipeline_with_named_tuple_fn(self):
    """A NamedTuple return field typed as a list of artifacts is rejected."""
    from typing import NamedTuple

    def comp(
        i: Input[List[Model]]
    ) -> NamedTuple('outputs', [('output_list', List[Artifact])]):
        ...

    expected_error = r'Cannot use output lists of artifacts in NamedTuple return annotations. Got output list of artifacts annotation for NamedTuple field `output_list`\.'

    with self.assertRaisesRegex(ValueError, expected_error):
        component_factory.extract_component_interface(comp)


class TestOutputListsOfArtifactsTemporarilyBlocked(unittest.TestCase):

Expand Down Expand Up @@ -176,6 +162,20 @@ def test_pipeline(self):
def comp() -> List[Artifact]:
...

def test_pipeline_with_named_tuple_fn(self):
    """Interface extraction fails for a NamedTuple output list of artifacts."""
    from typing import NamedTuple

    def comp(
        i: Input[List[Model]]
    ) -> NamedTuple('outputs', [('output_list', List[Artifact])]):
        ...

    # Callable form of assertRaisesRegex: the extraction call itself must
    # raise a ValueError whose message matches the pattern.
    self.assertRaisesRegex(
        ValueError,
        r'Cannot use output lists of artifacts in NamedTuple return annotations. Got output list of artifacts annotation for NamedTuple field `output_list`\.',
        component_factory.extract_component_interface,
        comp,
    )


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@

import unittest

from kfp.components import component_factory
from kfp.components import container_component_artifact_channel
from kfp.components import placeholders


class TestContainerComponentArtifactChannel(unittest.TestCase):

def test_correct_placeholder_and_attribute_error(self):
in_channel = component_factory.ContainerComponentArtifactChannel(
in_channel = container_component_artifact_channel.ContainerComponentArtifactChannel(
'input', 'my_dataset')
out_channel = component_factory.ContainerComponentArtifactChannel(
out_channel = container_component_artifact_channel.ContainerComponentArtifactChannel(
'output', 'my_result')
self.assertEqual(
in_channel.uri._to_string(),
Expand Down
25 changes: 24 additions & 1 deletion sdk/python/kfp/components/placeholders.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,28 @@ def _to_string(self) -> str:
return f"{{{{$.inputs.parameters['{self.input_name}']}}}}"


class InputListOfArtifactsPlaceholder(Placeholder):
    """Placeholder that refers to an input list of artifacts as a whole.

    Reading the attributes `name`, `uri`, `metadata` or `path`, or indexing
    into the placeholder, raises instead of returning a value.
    """

    def __init__(self, input_name: str) -> None:
        """Stores the name of the input this placeholder refers to."""
        self.input_name = input_name

    def _to_string(self) -> str:
        """Returns the placeholder string for the named input artifacts."""
        return f"{{{{$.inputs.artifacts['{self.input_name}']}}}}"

    def __getattribute__(self, name: str) -> Any:
        """Looks up `name` normally unless it is a blocked attribute.

        Raises:
            AttributeError: If `name` is one of `name`, `uri`, `metadata`
                or `path`.
        """
        if name not in {'name', 'uri', 'metadata', 'path'}:
            return object.__getattribute__(self, name)

        raise AttributeError(
            f'Cannot access an attribute on a list of artifacts in a Custom Container Component. Found reference to attribute {name!r} on {self.input_name!r}. Please pass the whole list of artifacts only.'
        )

    def __getitem__(self, k: int) -> None:
        """Always raises; individual elements cannot be selected.

        Raises:
            KeyError: On every call.
        """
        message = f'Cannot access individual artifacts in a list of artifacts. Found access to element {k} on {self.input_name!r}. Please pass the whole list of artifacts only.'
        raise KeyError(message)


class InputPathPlaceholder(Placeholder):

def __init__(self, input_name: str) -> None:
Expand Down Expand Up @@ -270,7 +292,8 @@ def __str__(self) -> str:

_CONTAINER_PLACEHOLDERS = (IfPresentPlaceholder, ConcatPlaceholder)
PRIMITIVE_INPUT_PLACEHOLDERS = (InputValuePlaceholder, InputPathPlaceholder,
InputUriPlaceholder, InputMetadataPlaceholder)
InputUriPlaceholder, InputMetadataPlaceholder,
InputListOfArtifactsPlaceholder)
PRIMITIVE_OUTPUT_PLACEHOLDERS = (OutputParameterPlaceholder,
OutputPathPlaceholder, OutputUriPlaceholder,
OutputMetadataPlaceholder)
Expand Down
Loading

0 comments on commit 4175260

Please sign in to comment.