Skip to content

Commit

Permalink
feat(sdk): add runtime logic for custom artifact types (support for c…
Browse files Browse the repository at this point in the history
…ustom artifact types pt. 3) (kubeflow#8233)

* add runtime artifact instance creation logic

* refactor executor

* add executor tests

* add custom artifact type import handling and tests

* fix artifact class construction

* fix custom artifact type in tests

* add typing extensions dependency for all python versions

* use mock google namespace artifact for tests

* remove print statement

* update google artifact golden snapshot

* resolve some review feedback

* remove handling for OutputPath and InputPath custom artifact types; update function names and tests

* clarify named tuple tests

* update executor tests

* add artifact return and named tuple support; refactor; clean tests

* implement review feedback; clean up artifact names

* move test method
  • Loading branch information
connor-mccarthy authored and jlyaoyuli committed Jan 5, 2023
1 parent 04f7cef commit 158f3e1
Show file tree
Hide file tree
Showing 14 changed files with 1,352 additions and 356 deletions.
11 changes: 6 additions & 5 deletions sdk/python/kfp/components/component_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from kfp.components import structures
from kfp.components.container_component_artifact_channel import \
ContainerComponentArtifactChannel
from kfp.components.types import custom_artifact_types
from kfp.components.types import type_annotations
from kfp.components.types import type_utils

Expand Down Expand Up @@ -171,7 +172,7 @@ def extract_component_interface(
# parameter_type is type_annotations.Artifact or one of its subclasses.
parameter_type = type_annotations.get_io_artifact_class(
parameter_type)
if not type_annotations.is_artifact(parameter_type):
if not type_annotations.is_artifact_class(parameter_type):
raise ValueError(
'Input[T] and Output[T] are only supported when T is a '
'subclass of Artifact. Found `{} with type {}`'.format(
Expand Down Expand Up @@ -203,7 +204,7 @@ def extract_component_interface(
]:
io_name = _maybe_make_unique(io_name, output_names)
output_names.add(io_name)
if type_annotations.is_artifact(parameter_type):
if type_annotations.is_artifact_class(parameter_type):
schema_version = parameter_type.schema_version
output_spec = structures.OutputSpec(
type=type_utils.create_bundled_artifact_type(
Expand All @@ -214,7 +215,7 @@ def extract_component_interface(
else:
io_name = _maybe_make_unique(io_name, input_names)
input_names.add(io_name)
if type_annotations.is_artifact(parameter_type):
if type_annotations.is_artifact_class(parameter_type):
schema_version = parameter_type.schema_version
input_spec = structures.InputSpec(
type=type_utils.create_bundled_artifact_type(
Expand Down Expand Up @@ -277,7 +278,7 @@ def extract_component_interface(
# `def func(output_path: OutputPath()) -> str: ...`
output_names.add(output_name)
return_ann = signature.return_annotation
if type_annotations.is_artifact(signature.return_annotation):
if type_annotations.is_artifact_class(signature.return_annotation):
output_spec = structures.OutputSpec(
type=type_utils.create_bundled_artifact_type(
return_ann.schema_title, return_ann.schema_version))
Expand Down Expand Up @@ -322,7 +323,7 @@ def _get_command_and_args_for_lightweight_component(
'from kfp import dsl',
'from kfp.dsl import *',
'from typing import *',
]
] + custom_artifact_types.get_custom_artifact_type_import_statements(func)

func_source = _get_function_source_definition(func)
source = textwrap.dedent('''
Expand Down
4 changes: 4 additions & 0 deletions sdk/python/kfp/components/component_factory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,7 @@ def test_with_packages_to_install_with_pip_index_url(self):
concat_command = ' '.join(command)
for package in packages_to_install + pip_index_urls:
self.assertTrue(package in concat_command)


if __name__ == '__main__':
unittest.main()
56 changes: 42 additions & 14 deletions sdk/python/kfp/components/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import inspect
import json
import os
from typing import Any, Callable, Dict, List, Optional, Union

from kfp.components import task_final_status
Expand All @@ -37,30 +38,40 @@ def __init__(self, executor_input: Dict, function_to_execute: Callable):
{}).get('artifacts', {}).items():
artifacts_list = artifacts.get('artifacts')
if artifacts_list:
self._input_artifacts[name] = self._make_input_artifact(
artifacts_list[0])
self._input_artifacts[name] = self.make_artifact(
artifacts_list[0],
name,
self._func,
)

for name, artifacts in self._input.get('outputs',
{}).get('artifacts', {}).items():
artifacts_list = artifacts.get('artifacts')
if artifacts_list:
self._output_artifacts[name] = self._make_output_artifact(
artifacts_list[0])
output_artifact = self.make_artifact(
artifacts_list[0],
name,
self._func,
)
self._output_artifacts[name] = output_artifact
self.makedirs_recursively(output_artifact.path)

self._return_annotation = inspect.signature(
self._func).return_annotation
self._executor_output = {}

@classmethod
def _make_input_artifact(cls, runtime_artifact: Dict):
return artifact_types.create_runtime_artifact(runtime_artifact)
def make_artifact(
self,
runtime_artifact: Dict,
name: str,
func: Callable,
) -> Any:
artifact_cls = func.__annotations__.get(name)
return create_artifact_instance(
runtime_artifact, artifact_cls=artifact_cls)

@classmethod
def _make_output_artifact(cls, runtime_artifact: Dict):
import os
artifact = artifact_types.create_runtime_artifact(runtime_artifact)
os.makedirs(os.path.dirname(artifact.path), exist_ok=True)
return artifact
def makedirs_recursively(self, path: str) -> None:
os.makedirs(os.path.dirname(path), exist_ok=True)

def _get_input_artifact(self, name: str):
return self._input_artifacts.get(name)
Expand Down Expand Up @@ -170,7 +181,7 @@ def _is_parameter(cls, annotation: Any) -> bool:
@classmethod
def _is_artifact(cls, annotation: Any) -> bool:
if type(annotation) == type:
return type_annotations.is_artifact(annotation)
return type_annotations.is_artifact_class(annotation)
return False

@classmethod
Expand Down Expand Up @@ -297,3 +308,20 @@ def execute(self):

result = self._func(**func_kwargs)
self._write_executor_output(result)


def create_artifact_instance(
runtime_artifact: Dict,
artifact_cls=artifact_types.Artifact,
) -> type:
"""Creates an artifact class instances from a runtime artifact
dictionary."""
schema_title = runtime_artifact.get('type', {}).get('schemaTitle', '')

artifact_cls = artifact_types._SCHEMA_TITLE_TO_TYPE.get(
schema_title, artifact_cls)
return artifact_cls(
uri=runtime_artifact.get('uri', ''),
name=runtime_artifact.get('name', ''),
metadata=runtime_artifact.get('metadata', {}),
)
Loading

0 comments on commit 158f3e1

Please sign in to comment.