Skip to content

Commit

Permalink
support list of artifact input placeholders
Browse files Browse the repository at this point in the history
  • Loading branch information
connor-mccarthy committed Feb 14, 2023
1 parent 8552226 commit 0d5533b
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 24 deletions.
58 changes: 39 additions & 19 deletions sdk/python/kfp/components/component_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
import pathlib
import re
import textwrap
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Optional, Tuple, Type, Union
import warnings

import docstring_parser
from kfp.components import container_component
from kfp.components import container_component_artifact_channel
from kfp.components import graph_component
from kfp.components import placeholders
from kfp.components import python_component
from kfp.components import structures
from kfp.components.container_component_artifact_channel import \
ContainerComponentArtifactChannel
from kfp.components.types import artifact_types
from kfp.components.types import custom_artifact_types
from kfp.components.types import type_annotations
from kfp.components.types import type_utils
Expand Down Expand Up @@ -477,6 +477,38 @@ def create_component_from_func(
component_spec=component_spec, python_func=func)


def make_input_for_parameterized_container_component_function(
    name: str, annotation: Union[Type[List[artifact_types.Artifact]],
                                 Type[artifact_types.Artifact]]
) -> Union[placeholders.Placeholder, container_component_artifact_channel
           .ContainerComponentArtifactChannel]:
    """Builds the stand-in object passed for one parameter of a container component function.

    The returned object is selected from the parameter's annotation:

    * input artifact annotation wrapping a list of artifacts ->
      ``placeholders.InputListOfArtifactsPlaceholder``
    * any other input artifact annotation -> an ``'input'``
      ``ContainerComponentArtifactChannel``
    * output artifact annotation wrapping a list of artifacts -> not
      supported, raises ``ValueError``
    * any other output artifact annotation -> an ``'output'``
      ``ContainerComponentArtifactChannel``
    * an ``OutputAnnotation`` / ``OutputPath`` instance ->
      ``placeholders.OutputParameterPlaceholder``
    * anything else -> ``placeholders.InputValuePlaceholder``

    Args:
        name: The parameter name; used as the placeholder / channel name.
        annotation: The parameter's type annotation.

    Returns:
        The placeholder or artifact channel representing the parameter.

    Raises:
        ValueError: If the annotation is an output list of artifacts.
    """
    channel_cls = container_component_artifact_channel.ContainerComponentArtifactChannel

    if type_annotations.is_input_artifact(annotation):
        # __origin__ is the type wrapped by the input annotation; it is only
        # read once the annotation is known to be an artifact annotation.
        if type_annotations.is_list_of_artifacts(annotation.__origin__):
            return placeholders.InputListOfArtifactsPlaceholder(name)
        return channel_cls(io_type='input', var_name=name)

    if type_annotations.is_output_artifact(annotation):
        if type_annotations.is_list_of_artifacts(annotation.__origin__):
            raise ValueError(
                'Outputting a list of artifacts from a Custom Container Component is not currently supported.'
            )
        return channel_cls(io_type='output', var_name=name)

    output_parameter_markers = (type_annotations.OutputAnnotation,
                                type_annotations.OutputPath)
    if isinstance(annotation, output_parameter_markers):
        return placeholders.OutputParameterPlaceholder(name)

    # Everything that is not an artifact or an output parameter is treated as
    # a plain input value.
    return placeholders.InputValuePlaceholder(name)


def create_container_component_from_func(
func: Callable) -> container_component.ContainerComponent:
"""Implementation for the @container_component decorator.
Expand All @@ -486,27 +518,15 @@ def create_container_component_from_func(
"""

component_spec = extract_component_interface(func, containerized=True)
arg_list = []
signature = inspect.signature(func)
parameters = list(signature.parameters.values())
arg_list = []
for parameter in parameters:
parameter_type = type_annotations.maybe_strip_optional_from_annotation(
parameter.annotation)
io_name = parameter.name
if type_annotations.is_input_artifact(parameter_type):
arg_list.append(
ContainerComponentArtifactChannel(
io_type='input', var_name=io_name))
elif type_annotations.is_output_artifact(parameter_type):
arg_list.append(
ContainerComponentArtifactChannel(
io_type='output', var_name=io_name))
elif isinstance(
parameter_type,
(type_annotations.OutputAnnotation, type_annotations.OutputPath)):
arg_list.append(placeholders.OutputParameterPlaceholder(io_name))
else: # parameter is an input value
arg_list.append(placeholders.InputValuePlaceholder(io_name))
arg_list.append(
make_input_for_parameterized_container_component_function(
parameter.name, parameter_type))

container_spec = func(*arg_list)
container_spec_implementation = structures.ContainerSpecImplementation.from_container_spec(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@

import unittest

from kfp.components import component_factory
from kfp.components import container_component_artifact_channel
from kfp.components import placeholders


class TestContainerComponentArtifactChannel(unittest.TestCase):

def test_correct_placeholder_and_attribute_error(self):
in_channel = component_factory.ContainerComponentArtifactChannel(
in_channel = container_component_artifact_channel.ContainerComponentArtifactChannel(
'input', 'my_dataset')
out_channel = component_factory.ContainerComponentArtifactChannel(
out_channel = container_component_artifact_channel.ContainerComponentArtifactChannel(
'output', 'my_result')
self.assertEqual(
in_channel.uri._to_string(),
Expand Down
25 changes: 24 additions & 1 deletion sdk/python/kfp/components/placeholders.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,28 @@ def _to_string(self) -> str:
return f"{{{{$.inputs.parameters['{self.input_name}']}}}}"


class InputListOfArtifactsPlaceholder(Placeholder):
    """Placeholder for an input that is a whole list of artifacts.

    Renders as ``{{$.inputs.artifacts['<input_name>']}}``. Only the list as a
    whole may be passed along: reading ``name``, ``uri``, ``metadata`` or
    ``path`` on the placeholder, or indexing into it, raises an error with an
    explanatory message.
    """

    def __init__(self, input_name: str) -> None:
        """Stores the name of the input this placeholder refers to."""
        self.input_name = input_name

    def _to_string(self) -> str:
        """Returns the placeholder string for the whole list of artifacts."""
        return f"{{{{$.inputs.artifacts['{self.input_name}']}}}}"

    def __getattribute__(self, name: str) -> Any:
        """Rejects the listed artifact attributes; defers everything else.

        Raises:
            AttributeError: If ``name`` is one of ``name``, ``uri``,
                ``metadata`` or ``path``.
        """
        if name in {'name', 'uri', 'metadata', 'path'}:
            # self.input_name is not in the blocked set, so this lookup falls
            # through to the normal attribute machinery below.
            raise AttributeError(
                f'Cannot access an attribute on a list of artifacts in a Custom Container Component. Found reference to attribute {name!r} on {self.input_name!r}. Please pass the whole list of artifacts only.'
            )
        else:
            return object.__getattribute__(self, name)

    def __getitem__(self, k: int) -> None:
        """Always raises: individual elements of the list cannot be selected.

        Raises:
            KeyError: On every call.
        """
        raise KeyError(
            f'Cannot access individual artifacts in a list of artifacts. Found access to element {k} on {self.input_name!r}. Please pass the whole list of artifacts only.'
        )


class InputPathPlaceholder(Placeholder):

def __init__(self, input_name: str) -> None:
Expand Down Expand Up @@ -270,7 +292,8 @@ def __str__(self) -> str:

_CONTAINER_PLACEHOLDERS = (IfPresentPlaceholder, ConcatPlaceholder)
PRIMITIVE_INPUT_PLACEHOLDERS = (InputValuePlaceholder, InputPathPlaceholder,
InputUriPlaceholder, InputMetadataPlaceholder)
InputUriPlaceholder, InputMetadataPlaceholder,
InputListOfArtifactsPlaceholder)
PRIMITIVE_OUTPUT_PLACEHOLDERS = (OutputParameterPlaceholder,
OutputPathPlaceholder, OutputUriPlaceholder,
OutputMetadataPlaceholder)
Expand Down
79 changes: 78 additions & 1 deletion sdk/python/kfp/components/placeholders_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
"""Contains tests for kfp.components.placeholders."""
import os
import tempfile
from typing import Any
from typing import Any, List

from absl.testing import parameterized
from kfp import compiler
from kfp import dsl
from kfp.components import placeholders
from kfp.dsl import Artifact
from kfp.dsl import Dataset
from kfp.dsl import Input
from kfp.dsl import Output


Expand Down Expand Up @@ -379,6 +381,81 @@ def test_valid_then_but_invalid_else(self):
])


class TestListOfArtifactsInContainerComponentPlaceholders(
        parameterized.TestCase):
    """Checks list-of-artifacts inputs in ``@dsl.container_component``."""

    def test_compile_component1(self):
        # A list of generic Artifacts is rendered as one whole-list
        # placeholder in both command and args.

        @dsl.container_component
        def comp(input_list: Input[List[Artifact]]):
            return dsl.ContainerSpec(
                image='alpine', command=[input_list], args=[input_list])

        expected = "{{$.inputs.artifacts['input_list']}}"
        container = comp.pipeline_spec.deployment_spec['executors'][
            'exec-comp']['container']
        self.assertEqual(container['command'][0], expected)
        self.assertEqual(container['args'][0], expected)

    def test_compile_component2(self):
        # Same check with a different artifact type and parameter name.

        @dsl.container_component
        def comp(new_name: Input[List[Dataset]]):
            return dsl.ContainerSpec(
                image='alpine', command=[new_name], args=[new_name])

        expected = "{{$.inputs.artifacts['new_name']}}"
        container = comp.pipeline_spec.deployment_spec['executors'][
            'exec-comp']['container']
        self.assertEqual(container['command'][0], expected)
        self.assertEqual(container['args'][0], expected)

    def test_cannot_access_name(self):
        # The error surfaces while the decorator evaluates the function body.
        expected_message = 'Cannot access an attribute'
        with self.assertRaisesRegex(AttributeError, expected_message):

            @dsl.container_component
            def comp(new_name: Input[List[Dataset]]):
                return dsl.ContainerSpec(
                    image='alpine', command=[new_name.name])

    def test_cannot_access_uri(self):
        expected_message = 'Cannot access an attribute'
        with self.assertRaisesRegex(AttributeError, expected_message):

            @dsl.container_component
            def comp(new_name: Input[List[Dataset]]):
                return dsl.ContainerSpec(
                    image='alpine', command=[new_name.uri])

    def test_cannot_access_metadata(self):
        expected_message = 'Cannot access an attribute'
        with self.assertRaisesRegex(AttributeError, expected_message):

            @dsl.container_component
            def comp(new_name: Input[List[Dataset]]):
                return dsl.ContainerSpec(
                    image='alpine', command=[new_name.metadata])

    def test_cannot_access_path(self):
        expected_message = 'Cannot access an attribute'
        with self.assertRaisesRegex(AttributeError, expected_message):

            @dsl.container_component
            def comp(new_name: Input[List[Dataset]]):
                return dsl.ContainerSpec(
                    image='alpine', command=[new_name.path])

    def test_cannot_access_individual_artifact(self):
        # Indexing into the list placeholder is rejected with a KeyError.
        expected_message = 'Cannot access individual artifacts'
        with self.assertRaisesRegex(KeyError, expected_message):

            @dsl.container_component
            def comp(new_name: Input[List[Dataset]]):
                return dsl.ContainerSpec(
                    image='alpine', command=[new_name[0]])


class TestConvertCommandLineElementToStringOrStruct(parameterized.TestCase):

@parameterized.parameters(['a', 'word', 1])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import List

from kfp import dsl
from kfp.dsl import Artifact
from kfp.dsl import Input


@dsl.container_component
def comp_with_artifact_list(input_list: Input[List[Artifact]]):
    # Pass the whole list-of-artifacts input through as both the container
    # command and its args. A comment is used here instead of a docstring so
    # the function object itself is unchanged.
    container_spec = dsl.ContainerSpec(
        image='alpine',
        command=[input_list],
        args=[input_list],
    )
    return container_spec


if __name__ == '__main__':
    from kfp import compiler

    # The compiled YAML is written to this module's path with '.py' replaced
    # by '.yaml'.
    yaml_path = __file__.replace('.py', '.yaml')
    compiler.Compiler().compile(
        pipeline_func=comp_with_artifact_list, package_path=yaml_path)
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# PIPELINE DEFINITION
# Name: comp-with-artifact-list
# Inputs:
# input_list: system.Artifact
components:
comp-comp-with-artifact-list:
executorLabel: exec-comp-with-artifact-list
inputDefinitions:
artifacts:
input_list:
artifactType:
schemaTitle: system.Artifact
schemaVersion: 0.0.1
deploymentSpec:
executors:
exec-comp-with-artifact-list:
container:
args:
- '{{$.inputs.artifacts[''input_list'']}}'
command:
- '{{$.inputs.artifacts[''input_list'']}}'
image: alpine
pipelineInfo:
name: comp-with-artifact-list
root:
dag:
tasks:
comp-with-artifact-list:
cachingOptions:
enableCache: true
componentRef:
name: comp-comp-with-artifact-list
inputs:
artifacts:
input_list:
componentInputArtifact: input_list
taskInfo:
name: comp-with-artifact-list
inputDefinitions:
artifacts:
input_list:
artifactType:
schemaTitle: system.Artifact
schemaVersion: 0.0.1
schemaVersion: 2.1.0
sdkVersion: kfp-2.0.0-beta.8
3 changes: 3 additions & 0 deletions sdk/python/test_data/test_data_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ components:
- module: container_with_placeholder_in_fstring
name: container_with_placeholder_in_fstring
execute: false
- module: container_component_with_list_of_artifacts
name: comp_with_artifact_list
execute: false
v1_components:
test_data_dir: sdk/python/test_data/v1_component_yaml
read: true
Expand Down

0 comments on commit 0d5533b

Please sign in to comment.