From 158f3e10fa00d353c41332764113981470dc79d2 Mon Sep 17 00:00:00 2001 From: Connor McCarthy Date: Wed, 14 Sep 2022 19:00:40 -0600 Subject: [PATCH] feat(sdk): add runtime logic for custom artifact types (support for custom artifact types pt. 3) (#8233) * add runtime artifact instance creation logic * refactor executor * add executor tests * add custom artifact type import handling and tests * fix artifact class construction * fix custom artifact type in tests * add typing extensions dependency for all python versions * use mock google namespace artifact for tests * remove print statement * update google artifact golden snapshot * resolve some review feedback * remove handling for OutputPath and InputPath custom artifact types; update function names and tests * clarify named tuple tests * update executor tests * add artifact return and named tuple support; refactor; clean tests * implement review feedback; clean up artifact names * move test method --- .../kfp/components/component_factory.py | 11 +- .../kfp/components/component_factory_test.py | 4 + sdk/python/kfp/components/executor.py | 56 +- sdk/python/kfp/components/executor_test.py | 744 +++++++++++++----- sdk/python/kfp/components/structures.py | 1 - .../kfp/components/types/artifact_types.py | 18 - .../components/types/artifact_types_test.py | 99 --- .../components/types/custom_artifact_types.py | 191 +++++ .../types/custom_artifact_types_test.py | 527 +++++++++++++ .../kfp/components/types/type_annotations.py | 4 +- .../components/types/type_annotations_test.py | 10 +- sdk/python/kfp/components/types/type_utils.py | 8 +- sdk/python/requirements.txt | 24 +- .../pipeline_with_google_artifact_type.yaml | 11 +- 14 files changed, 1352 insertions(+), 356 deletions(-) create mode 100644 sdk/python/kfp/components/types/custom_artifact_types.py create mode 100644 sdk/python/kfp/components/types/custom_artifact_types_test.py diff --git a/sdk/python/kfp/components/component_factory.py b/sdk/python/kfp/components/component_factory.py index c93f0d5a313..01bf5f2be49 100644 --- a/sdk/python/kfp/components/component_factory.py +++ b/sdk/python/kfp/components/component_factory.py @@ -28,6 +28,7 @@ from kfp.components import structures from kfp.components.container_component_artifact_channel import \ ContainerComponentArtifactChannel +from kfp.components.types import custom_artifact_types from kfp.components.types import type_annotations from kfp.components.types import type_utils @@ -171,7 +172,7 @@ def extract_component_interface( # parameter_type is type_annotations.Artifact or one of its subclasses. parameter_type = type_annotations.get_io_artifact_class( parameter_type) - if not type_annotations.is_artifact(parameter_type): + if not type_annotations.is_artifact_class(parameter_type): raise ValueError( 'Input[T] and Output[T] are only supported when T is a ' 'subclass of Artifact. Found `{} with type {}`'.format( @@ -203,7 +204,7 @@ def extract_component_interface( ]: io_name = _maybe_make_unique(io_name, output_names) output_names.add(io_name) - if type_annotations.is_artifact(parameter_type): + if type_annotations.is_artifact_class(parameter_type): schema_version = parameter_type.schema_version output_spec = structures.OutputSpec( type=type_utils.create_bundled_artifact_type( @@ -214,7 +215,7 @@ def extract_component_interface( else: io_name = _maybe_make_unique(io_name, input_names) input_names.add(io_name) - if type_annotations.is_artifact(parameter_type): + if type_annotations.is_artifact_class(parameter_type): schema_version = parameter_type.schema_version input_spec = structures.InputSpec( type=type_utils.create_bundled_artifact_type( @@ -277,7 +278,7 @@ def extract_component_interface( # `def func(output_path: OutputPath()) -> str: ...` output_names.add(output_name) return_ann = signature.return_annotation - if type_annotations.is_artifact(signature.return_annotation): + if type_annotations.is_artifact_class(signature.return_annotation): output_spec = structures.OutputSpec( type=type_utils.create_bundled_artifact_type( return_ann.schema_title, return_ann.schema_version)) @@ -322,7 +323,7 @@ def _get_command_and_args_for_lightweight_component( 'from kfp import dsl', 'from kfp.dsl import *', 'from typing import *', - ] + ] + custom_artifact_types.get_custom_artifact_type_import_statements(func) func_source = _get_function_source_definition(func) source = textwrap.dedent(''' diff --git a/sdk/python/kfp/components/component_factory_test.py b/sdk/python/kfp/components/component_factory_test.py index 6b984b6962b..c20da8fc6ff 100644 --- a/sdk/python/kfp/components/component_factory_test.py +++ b/sdk/python/kfp/components/component_factory_test.py @@ -44,3 +44,7 @@ def test_with_packages_to_install_with_pip_index_url(self): concat_command = ' '.join(command) for package in packages_to_install + pip_index_urls: self.assertTrue(package in concat_command) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdk/python/kfp/components/executor.py b/sdk/python/kfp/components/executor.py index ffac9008f85..2c367131d2d 100644 --- a/sdk/python/kfp/components/executor.py +++ b/sdk/python/kfp/components/executor.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect import json +import os from typing import Any, Callable, Dict, List, Optional, Union from kfp.components import task_final_status @@ -37,30 +38,40 @@ def __init__(self, executor_input: Dict, function_to_execute: Callable): {}).get('artifacts', {}).items(): artifacts_list = artifacts.get('artifacts') if artifacts_list: - self._input_artifacts[name] = self._make_input_artifact( - artifacts_list[0]) + self._input_artifacts[name] = self.make_artifact( + artifacts_list[0], + name, + self._func, + ) for name, artifacts in self._input.get('outputs', {}).get('artifacts', {}).items(): artifacts_list = artifacts.get('artifacts') if artifacts_list: - self._output_artifacts[name] = self._make_output_artifact( - artifacts_list[0]) + output_artifact = self.make_artifact( + artifacts_list[0], + name, + self._func, + ) + self._output_artifacts[name] = output_artifact + self.makedirs_recursively(output_artifact.path) self._return_annotation = inspect.signature( self._func).return_annotation self._executor_output = {} - @classmethod - def _make_input_artifact(cls, runtime_artifact: Dict): - return artifact_types.create_runtime_artifact(runtime_artifact) + def make_artifact( + self, + runtime_artifact: Dict, + name: str, + func: Callable, + ) -> Any: + artifact_cls = func.__annotations__.get(name) + return create_artifact_instance( + runtime_artifact, artifact_cls=artifact_cls) - @classmethod - def _make_output_artifact(cls, runtime_artifact: Dict): - import os - artifact = artifact_types.create_runtime_artifact(runtime_artifact) - os.makedirs(os.path.dirname(artifact.path), exist_ok=True) - return artifact + def makedirs_recursively(self, path: str) -> None: + os.makedirs(os.path.dirname(path), exist_ok=True) def _get_input_artifact(self, name: str): return self._input_artifacts.get(name) @@ -170,7 +181,7 @@ def _is_parameter(cls, annotation: Any) -> bool: @classmethod def _is_artifact(cls, annotation: Any) -> bool: if type(annotation) == type: - return type_annotations.is_artifact(annotation) + return type_annotations.is_artifact_class(annotation) return False @classmethod @@ -297,3 +308,20 @@ def execute(self): result = self._func(**func_kwargs) self._write_executor_output(result) + + +def create_artifact_instance( + runtime_artifact: Dict, + artifact_cls=artifact_types.Artifact, +) -> type: + """Creates an artifact class instances from a runtime artifact + dictionary.""" + schema_title = runtime_artifact.get('type', {}).get('schemaTitle', '') + + artifact_cls = artifact_types._SCHEMA_TITLE_TO_TYPE.get( + schema_title, artifact_cls) + return artifact_cls( + uri=runtime_artifact.get('uri', ''), + name=runtime_artifact.get('name', ''), + metadata=runtime_artifact.get('metadata', {}), + ) diff --git a/sdk/python/kfp/components/executor_test.py b/sdk/python/kfp/components/executor_test.py index cc54c005e07..6846a0f92ca 100644 --- a/sdk/python/kfp/components/executor_test.py +++ b/sdk/python/kfp/components/executor_test.py @@ -19,6 +19,7 @@ from typing import Callable, Dict, List, NamedTuple, Optional import unittest +from absl.testing import parameterized from kfp.components import executor from kfp.components.task_final_status import PipelineTaskFinalStatus from kfp.components.types import artifact_types @@ -31,64 +32,6 @@ from kfp.components.types.type_annotations import Output from kfp.components.types.type_annotations import OutputPath -_EXECUTOR_INPUT = """\ -{ - "inputs": { - "parameterValues": { - "input_parameter": "Hello, KFP" - }, - "artifacts": { - "input_artifact_one_path": { - "artifacts": [ - { - "metadata": {}, - "name": "input_artifact_one", - "type": { - "schemaTitle": "system.Dataset" - }, - "uri": "gs://some-bucket/input_artifact_one" - } - ] - } - } - }, - "outputs": { - "artifacts": { - "output_artifact_one_path": { - "artifacts": [ - { - "metadata": {}, - "name": "output_artifact_one", - "type": { - "schemaTitle": "system.Model" - }, - "uri": "gs://some-bucket/output_artifact_one" - } - ] - }, - "output_artifact_two": { - "artifacts": [ - { - "metadata": {}, - "name": "output_artifact_two", - "type": { - "schemaTitle": "system.Metrics" - }, - "uri": "gs://some-bucket/output_artifact_two" - } - ] - } - }, - "parameters": { - "output_parameter_path": { - "outputFile": "%(test_dir)s/gcs/some-bucket/some_task/nested/output_parameter" - } - }, - "outputFile": "%(test_dir)s/output_metadata.json" - } -} -""" - class ExecutorTest(unittest.TestCase): @@ -100,52 +43,153 @@ def setUp(cls): artifact_types._MINIO_LOCAL_MOUNT_PREFIX = cls._test_dir + '/minio/' artifact_types._S3_LOCAL_MOUNT_PREFIX = cls._test_dir + '/s3/' - def _get_executor( - self, - func: Callable, - executor_input: Optional[str] = None) -> executor.Executor: - if executor_input is None: - executor_input = _EXECUTOR_INPUT + def execute_and_load_output_metadata(self, func: Callable, + executor_input: str) -> dict: executor_input_dict = json.loads(executor_input % {'test_dir': self._test_dir}) - return executor.Executor( - executor_input=executor_input_dict, function_to_execute=func) + executor.Executor( + executor_input=executor_input_dict, + function_to_execute=func).execute() + with open(os.path.join(self._test_dir, 'output_metadata.json'), + 'r') as f: + return json.loads(f.read()) - def test_input_parameter(self): + def test_input_and_output_parameters(self): + executor_input = """\ + { + "inputs": { + "parameterValues": { + "input_parameter": "Hello, KFP" + } + }, + "outputs": { + "parameters": { + "Output": { + "outputFile": "gs://some-bucket/output" + } + }, + "outputFile": "%(test_dir)s/output_metadata.json" + } + } + """ - def test_func(input_parameter: str): + def test_func(input_parameter: str) -> str: self.assertEqual(input_parameter, 'Hello, KFP') + return input_parameter - self._get_executor(test_func).execute() + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + self.assertEqual({'parameterValues': { + 'Output': 'Hello, KFP' + }}, output_metadata) - def test_input_artifact(self): + def test_input_artifact_custom_type(self): + executor_input = """\ + { + "inputs": { + "artifacts": { + "input_artifact_one": { + "artifacts": [ + { + "metadata": {}, + "name": "projects/123/locations/us-central1/metadataStores/default/artifacts/123", + "type": { + "schemaTitle": "google.VertexDataset" + }, + "uri": "gs://some-bucket/input_artifact_one" + } + ] + } + } + }, + "outputs": { + "outputFile": "%(test_dir)s/output_metadata.json" + } + } + """ + + class VertexDataset: + schema_title = 'google.VertexDataset' + schema_version = '0.0.0' + + def __init__(self, name: str, uri: str, metadata: dict) -> None: + self.name = name + self.uri = uri + self.metadata = metadata + + @property + def path(self) -> str: + return self.uri.replace('gs://', + artifact_types._GCS_LOCAL_MOUNT_PREFIX) - def test_func(input_artifact_one_path: Input[Dataset]): - self.assertEqual(input_artifact_one_path.uri, + def test_func(input_artifact_one: Input[VertexDataset]): + self.assertEqual(input_artifact_one.uri, 'gs://some-bucket/input_artifact_one') self.assertEqual( - input_artifact_one_path.path, - os.path.join(self._test_dir, 'some-bucket/input_artifact_one')) - self.assertEqual(input_artifact_one_path.name, 'input_artifact_one') - - self._get_executor(test_func).execute() + input_artifact_one.path, + os.path.join(artifact_types._GCS_LOCAL_MOUNT_PREFIX, + 'some-bucket/input_artifact_one')) + self.assertEqual( + input_artifact_one.name, + 'projects/123/locations/us-central1/metadataStores/default/artifacts/123' + ) + self.assertIsInstance(input_artifact_one, VertexDataset) - def test_output_artifact(self): + self.execute_and_load_output_metadata(test_func, executor_input) - def test_func(output_artifact_one_path: Output[Model]): - self.assertEqual(output_artifact_one_path.uri, - 'gs://some-bucket/output_artifact_one') + def test_input_artifact(self): + executor_input = """\ + { + "inputs": { + "artifacts": { + "input_artifact_one": { + "artifacts": [ + { + "metadata": {}, + "name": "projects/123/locations/us-central1/metadataStores/default/artifacts/123", + "type": { + "schemaTitle": "google.VertexDataset" + }, + "uri": "gs://some-bucket/input_artifact_one" + } + ] + } + } + }, + "outputs": { + "outputFile": "%(test_dir)s/output_metadata.json" + } + } + """ + def test_func(input_artifact_one: Input[Dataset]): + self.assertEqual(input_artifact_one.uri, + 'gs://some-bucket/input_artifact_one') self.assertEqual( - output_artifact_one_path.path, - os.path.join(self._test_dir, 'some-bucket/output_artifact_one')) - self.assertEqual(output_artifact_one_path.name, - 'output_artifact_one') + input_artifact_one.path, + os.path.join(self._test_dir, 'some-bucket/input_artifact_one')) + self.assertEqual( + input_artifact_one.name, + 'projects/123/locations/us-central1/metadataStores/default/artifacts/123' + ) + self.assertIsInstance(input_artifact_one, Dataset) - self._get_executor(test_func).execute() + self.execute_and_load_output_metadata(test_func, executor_input) def test_output_parameter(self): + executor_input = """\ + { + "outputs": { + "parameters": { + "output_parameter_path": { + "outputFile": "%(test_dir)s/gcs/some-bucket/some_task/nested/output_parameter" + } + }, + "outputFile": "%(test_dir)s/output_metadata.json" + } + } + """ def test_func(output_parameter_path: OutputPath(str)): # Test that output parameters just use the passed in filename. @@ -155,27 +199,92 @@ def test_func(output_parameter_path: OutputPath(str)): with open(output_parameter_path, 'w') as f: f.write('Hello, World!') - self._get_executor(test_func).execute() + self.execute_and_load_output_metadata(test_func, executor_input) def test_input_path_artifact(self): + executor_input = """\ + { + "inputs": { + "artifacts": { + "input_artifact_one_path": { + "artifacts": [ + { + "metadata": {}, + "name": "projects/123/locations/us-central1/metadataStores/default/artifacts/123", + "type": { + "schemaTitle": "system.Dataset" + }, + "uri": "gs://some-bucket/input_artifact_one" + } + ] + } + } + }, + "outputs": { + "outputFile": "%(test_dir)s/output_metadata.json" + } + } + """ def test_func(input_artifact_one_path: InputPath('Dataset')): self.assertEqual( input_artifact_one_path, os.path.join(self._test_dir, 'some-bucket/input_artifact_one')) - self._get_executor(test_func).execute() + self.execute_and_load_output_metadata(test_func, executor_input) def test_output_path_artifact(self): + executor_input = """\ + { + "outputs": { + "artifacts": { + "output_artifact_one_path": { + "artifacts": [ + { + "metadata": {}, + "name": "projects/123/locations/us-central1/metadataStores/default/artifacts/123", + "type": { + "schemaTitle": "system.Model" + }, + "uri": "gs://some-bucket/output_artifact_one" + } + ] + } + }, + "outputFile": "%(test_dir)s/output_metadata.json" + } + } + """ def test_func(output_artifact_one_path: OutputPath('Model')): self.assertEqual( output_artifact_one_path, os.path.join(self._test_dir, 'some-bucket/output_artifact_one')) - self._get_executor(test_func).execute() + self.execute_and_load_output_metadata(test_func, executor_input) def test_output_metadata(self): + executor_input = """\ + { + "outputs": { + "artifacts": { + "output_artifact_two": { + "artifacts": [ + { + "metadata": {}, + "name": "projects/123/locations/us-central1/metadataStores/default/artifacts/123", + "type": { + "schemaTitle": "system.Metrics" + }, + "uri": "gs://some-bucket/output_artifact_two" + } + ] + } + }, + "outputFile": "%(test_dir)s/output_metadata.json" + } + } + """ def test_func(output_artifact_two: Output[Metrics]): output_artifact_two.metadata['key_1'] = 'value_1' @@ -185,24 +294,18 @@ def test_func(output_artifact_two: Output[Metrics]): # log_metric works here since the schema is specified as Metrics. output_artifact_two.log_metric('metric', 0.9) - self._get_executor(test_func).execute() - with open(os.path.join(self._test_dir, 'output_metadata.json'), - 'r') as f: - output_metadata = json.loads(f.read()) + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + self.assertDictEqual( output_metadata, { 'artifacts': { - 'output_artifact_one_path': { - 'artifacts': [{ - 'name': 'output_artifact_one', - 'uri': 'gs://some-bucket/output_artifact_one', - 'metadata': {} - }] - }, 'output_artifact_two': { 'artifacts': [{ - 'name': 'output_artifact_two', - 'uri': 'new-uri', + 'name': + 'projects/123/locations/us-central1/metadataStores/default/artifacts/123', + 'uri': + 'new-uri', 'metadata': { 'key_1': 'value_1', 'key_2': 2, @@ -219,13 +322,13 @@ def test_function_string_output(self): "inputs": { "parameterValues": { "first_message": "Hello", - "second_message": "", + "second_message": ", ", "third_message": "World" } }, "outputs": { "parameters": { - "output": { + "Output": { "outputFile": "gs://some-bucket/output" } }, @@ -239,15 +342,13 @@ def test_func( second_message: str, third_message: str, ) -> str: - return first_message + ', ' + second_message + ', ' + third_message + return first_message + second_message + third_message - self._get_executor(test_func, executor_input).execute() - with open(os.path.join(self._test_dir, 'output_metadata.json'), - 'r') as f: - output_metadata = json.loads(f.read()) + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) self.assertDictEqual(output_metadata, { 'parameterValues': { - 'Output': 'Hello, , World' + 'Output': 'Hello, World' }, }) @@ -262,7 +363,7 @@ def test_function_with_int_output(self): }, "outputs": { "parameters": { - "output": { + "Output": { "outputFile": "gs://some-bucket/output" } }, @@ -274,10 +375,8 @@ def test_function_with_int_output(self): def test_func(first: int, second: int) -> int: return first + second - self._get_executor(test_func, executor_input).execute() - with open(os.path.join(self._test_dir, 'output_metadata.json'), - 'r') as f: - output_metadata = json.loads(f.read()) + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) self.assertDictEqual(output_metadata, { 'parameterValues': { 'Output': 42 @@ -295,7 +394,7 @@ def test_function_with_float_output(self): }, "outputs": { "parameters": { - "output": { + "Output": { "outputFile": "gs://some-bucket/output" } }, @@ -307,10 +406,9 @@ def test_function_with_float_output(self): def test_func(first: float, second: float) -> float: return first + second - self._get_executor(test_func, executor_input).execute() - with open(os.path.join(self._test_dir, 'output_metadata.json'), - 'r') as f: - output_metadata = json.loads(f.read()) + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + self.assertDictEqual(output_metadata, { 'parameterValues': { 'Output': 1.2 @@ -328,7 +426,7 @@ def test_function_with_list_output(self): }, "outputs": { "parameters": { - "output": { + "Output": { "outputFile": "gs://some-bucket/output" } }, @@ -340,10 +438,9 @@ def test_function_with_list_output(self): def test_func(first: int, second: int) -> List: return [first, second] - self._get_executor(test_func, executor_input).execute() - with open(os.path.join(self._test_dir, 'output_metadata.json'), - 'r') as f: - output_metadata = json.loads(f.read()) + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + self.assertDictEqual(output_metadata, { 'parameterValues': { 'Output': [40, 2] @@ -361,7 +458,7 @@ def test_function_with_dict_output(self): }, "outputs": { "parameters": { - "output": { + "Output": { "outputFile": "gs://some-bucket/output" } }, @@ -373,10 +470,9 @@ def test_function_with_dict_output(self): def test_func(first: int, second: int) -> Dict: return {'first': first, 'second': second} - self._get_executor(test_func, executor_input).execute() - with open(os.path.join(self._test_dir, 'output_metadata.json'), - 'r') as f: - output_metadata = json.loads(f.read()) + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + self.assertDictEqual(output_metadata, { 'parameterValues': { 'Output': { @@ -397,7 +493,7 @@ def test_function_with_typed_list_output(self): }, "outputs": { "parameters": { - "output": { + "Output": { "outputFile": "gs://some-bucket/output" } }, @@ -409,10 +505,9 @@ def test_function_with_typed_list_output(self): def test_func(first: int, second: int) -> List[int]: return [first, second] - self._get_executor(test_func, executor_input).execute() - with open(os.path.join(self._test_dir, 'output_metadata.json'), - 'r') as f: - output_metadata = json.loads(f.read()) + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + self.assertDictEqual(output_metadata, { 'parameterValues': { 'Output': [40, 2] @@ -430,7 +525,7 @@ def test_function_with_typed_dict_output(self): }, "outputs": { "parameters": { - "output": { + "Output": { "outputFile": "gs://some-bucket/output" } }, @@ -442,10 +537,9 @@ def test_function_with_typed_dict_output(self): def test_func(first: int, second: int) -> Dict[str, int]: return {'first': first, 'second': second} - self._get_executor(test_func, executor_input).execute() - with open(os.path.join(self._test_dir, 'output_metadata.json'), - 'r') as f: - output_metadata = json.loads(f.read()) + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + self.assertDictEqual(output_metadata, { 'parameterValues': { 'Output': { @@ -455,7 +549,7 @@ def test_func(first: int, second: int) -> Dict[str, int]: }, }) - def test_artifact_output(self): + def test_artifact_output1(self): executor_input = """\ { "inputs": { @@ -466,10 +560,11 @@ def test_artifact_output(self): }, "outputs": { "artifacts": { - "Output": { + "output": { "artifacts": [ { - "name": "output", + "metadata": {}, + "name": "projects/123/locations/us-central1/metadataStores/default/artifacts/123", "type": { "schemaTitle": "system.Artifact" }, @@ -478,65 +573,169 @@ def test_artifact_output(self): ] } }, + "parameters": { + "Output": { + "outputFile": "gs://some-bucket/output" + } + }, "outputFile": "%(test_dir)s/output_metadata.json" } } """ - def test_func(first: str, second: str) -> Artifact: + def test_func(first: str, second: str, output: Output[Artifact]) -> str: + with open(output.path, 'w') as f: + f.write('artifact output') return first + ', ' + second - self._get_executor(test_func, executor_input).execute() - with open(os.path.join(self._test_dir, 'output_metadata.json'), - 'r') as f: - output_metadata = json.loads(f.read()) + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + self.assertDictEqual( output_metadata, { 'artifacts': { - 'Output': { + 'output': { 'artifacts': [{ 'metadata': {}, - 'name': 'output', - 'uri': 'gs://some-bucket/output' + 'name': + 'projects/123/locations/us-central1/metadataStores/default/artifacts/123', + 'uri': + 'gs://some-bucket/output' }] } + }, + 'parameterValues': { + 'Output': 'Hello, World' } }) with open(os.path.join(self._test_dir, 'some-bucket/output'), 'r') as f: artifact_payload = f.read() - self.assertEqual(artifact_payload, 'Hello, World') + self.assertEqual(artifact_payload, 'artifact output') - def test_named_tuple_output(self): + def test_artifact_output2(self): executor_input = """\ { + "inputs": { + "parameterValues": { + "first": "Hello", + "second": "World" + } + }, "outputs": { "artifacts": { - "output_dataset": { + "Output": { "artifacts": [ { - "name": "output_dataset", + "metadata": {}, + "name": "projects/123/locations/us-central1/metadataStores/default/artifacts/123", "type": { - "schemaTitle": "system.Dataset" + "schemaTitle": "system.Artifact" }, - "uri": "gs://some-bucket/output_dataset" + "uri": "gs://some-bucket/output" } ] } }, - "parameters": { - "output_int": { - "outputFile": "gs://some-bucket/output_int" - }, - "output_string": { - "outputFile": "gs://some-bucket/output_string" - } - }, "outputFile": "%(test_dir)s/output_metadata.json" } } """ + def test_func(first: str, second: str) -> Artifact: + return first + ', ' + second + + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + + self.assertDictEqual( + output_metadata, { + 'artifacts': { + 'Output': { + 'artifacts': [{ + 'metadata': {}, + 'name': + 'projects/123/locations/us-central1/metadataStores/default/artifacts/123', + 'uri': + 'gs://some-bucket/output' + }] + } + }, + }) + + with open(os.path.join(self._test_dir, 'some-bucket/output'), 'r') as f: + artifact_payload = f.read() + self.assertEqual(artifact_payload, 'Hello, World') + + def test_output_artifact3(self): + executor_input = """\ + { + "outputs": { + "artifacts": { + "output_artifact_one": { + "artifacts": [ + { + "metadata": {}, + "name": "projects/123/locations/us-central1/metadataStores/default/artifacts/123", + "type": { + "schemaTitle": "system.Model" + }, + "uri": "gs://some-bucket/output_artifact_one" + } + ] + } + }, + "outputFile": "%(test_dir)s/output_metadata.json" + } + } + """ + + def test_func(output_artifact_one: Output[Model]): + self.assertEqual(output_artifact_one.uri, + 'gs://some-bucket/output_artifact_one') + + self.assertEqual( + output_artifact_one.path, + os.path.join(self._test_dir, 'some-bucket/output_artifact_one')) + self.assertEqual( + output_artifact_one.name, + 'projects/123/locations/us-central1/metadataStores/default/artifacts/123' + ) + self.assertIsInstance(output_artifact_one, Model) + + self.execute_and_load_output_metadata(test_func, executor_input) + + def test_named_tuple_output(self): + executor_input = """\ + { + "outputs": { + "artifacts": { + "output_dataset": { + "artifacts": [ + { + "metadata": {}, + "name": "projects/123/locations/us-central1/metadataStores/default/artifacts/123", + "type": { + "schemaTitle": "system.Dataset" + }, + "uri": "gs://some-bucket/output_dataset" + } + ] + } + }, + "parameters": { + "output_int": { + "outputFile": "gs://some-bucket/output_int" + }, + "output_string": { + "outputFile": "gs://some-bucket/output_string" + } + }, + "outputFile": "%(test_dir)s/output_metadata.json" + } + } + """ + # Functions returning named tuples should work. def func_returning_named_tuple() -> NamedTuple('Outputs', [ ('output_dataset', Dataset), @@ -559,19 +758,19 @@ def func_returning_plain_tuple() -> NamedTuple('Outputs', [ for test_func in [ func_returning_named_tuple, func_returning_plain_tuple ]: - self._get_executor(test_func, executor_input).execute() - with open( - os.path.join(self._test_dir, 'output_metadata.json'), - 'r') as f: - output_metadata = json.loads(f.read()) + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + self.assertDictEqual( output_metadata, { 'artifacts': { 'output_dataset': { 'artifacts': [{ 'metadata': {}, - 'name': 'output_dataset', - 'uri': 'gs://some-bucket/output_dataset' + 'name': + 'projects/123/locations/us-central1/metadataStores/default/artifacts/123', + 'uri': + 'gs://some-bucket/output_dataset' }] } }, @@ -589,29 +788,29 @@ def func_returning_plain_tuple() -> NamedTuple('Outputs', [ def test_function_with_optional_inputs(self): executor_input = """\ - { - "inputs": { - "parameterValues": { - "first_message": "Hello", - "second_message": "World" - } - }, - "outputs": { - "parameters": { - "output": { - "outputFile": "gs://some-bucket/output" + { + "inputs": { + "parameterValues": { + "first_message": "Hello", + "second_message": "World" + } + }, + "outputs": { + "parameters": { + "Output": { + "outputFile": "gs://some-bucket/output" + } + }, + "outputFile": "%(test_dir)s/output_metadata.json" } - }, - "outputFile": "%(test_dir)s/output_metadata.json" - } - } - """ + } + """ def test_func( first_message: str = 'default value', second_message: Optional[str] = None, third_message: Optional[str] = None, - forth_argument: str = 'abc', + fourth_argument: str = 'abc', fifth_argument: int = 100, sixth_argument: float = 1.23, seventh_argument: bool = True, @@ -621,17 +820,16 @@ def test_func( return (f'{first_message} ({type(first_message)}), ' f'{second_message} ({type(second_message)}), ' f'{third_message} ({type(third_message)}), ' - f'{forth_argument} ({type(forth_argument)}), ' + f'{fourth_argument} ({type(fourth_argument)}), ' f'{fifth_argument} ({type(fifth_argument)}), ' f'{sixth_argument} ({type(sixth_argument)}), ' f'{seventh_argument} ({type(seventh_argument)}), ' f'{eighth_argument} ({type(eighth_argument)}), ' f'{ninth_argument} ({type(ninth_argument)}).') - self._get_executor(test_func, executor_input).execute() - with open(os.path.join(self._test_dir, 'output_metadata.json'), - 'r') as f: - output_metadata = json.loads(f.read()) + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + self.assertDictEqual( output_metadata, { 'parameterValues': { @@ -657,7 +855,7 @@ def test_function_with_pipeline_task_final_status(self): }, "outputs": { "parameters": { - "output": { + "Output": { "outputFile": "gs://some-bucket/output" } }, @@ -673,10 +871,9 @@ def test_func(status: PipelineTaskFinalStatus) -> str: f'Error code: {status.error_code}\n' f'Error message: {status.error_message}') - self._get_executor(test_func, executor_input).execute() - with open(os.path.join(self._test_dir, 'output_metadata.json'), - 'r') as f: - output_metadata = json.loads(f.read()) + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + self.assertDictEqual( output_metadata, { 'parameterValues': { @@ -690,5 +887,154 @@ def test_func(status: PipelineTaskFinalStatus) -> str: }) +class VertexDataset: + schema_title = 'google.VertexDataset' + schema_version = '0.0.0' + + def __init__(self, name: str, uri: str, metadata: dict) -> None: + self.name = name + self.uri = uri + self.metadata = metadata + + @property + def path(self) -> str: + return self.uri.replace('gs://', artifact_types._GCS_LOCAL_MOUNT_PREFIX) + + +class TestDictToArtifact(parameterized.TestCase): + + @parameterized.parameters( + { + 'runtime_artifact': { + 'metadata': {}, + 'name': 'input_artifact_one', + 'type': { + 'schemaTitle': 'system.Artifact' + }, + 'uri': 'gs://some-bucket/input_artifact_one' + }, + 'artifact_cls': artifact_types.Artifact, + 'expected_type': artifact_types.Artifact, + }, + { + 'runtime_artifact': { + 'metadata': {}, + 'name': 'input_artifact_one', + 'type': { + 'schemaTitle': 'system.Model' + }, + 'uri': 'gs://some-bucket/input_artifact_one' + }, + 'artifact_cls': artifact_types.Model, + 'expected_type': artifact_types.Model, + }, + { + 'runtime_artifact': { + 'metadata': {}, + 'name': 'input_artifact_one', + 'type': { + 'schemaTitle': 'system.Dataset' + }, + 'uri': 'gs://some-bucket/input_artifact_one' + }, + 'artifact_cls': artifact_types.Dataset, + 'expected_type': artifact_types.Dataset, + }, + { + 'runtime_artifact': { + 'metadata': {}, + 'name': 'input_artifact_one', + 'type': { + 'schemaTitle': 'system.Metrics' + }, + 'uri': 'gs://some-bucket/input_artifact_one' + }, + 'artifact_cls': artifact_types.Metrics, + 'expected_type': artifact_types.Metrics, + }, + { + 'runtime_artifact': { + 'metadata': {}, + 'name': 'input_artifact_one', + 'type': { + 'schemaTitle': 'system.ClassificationMetrics' + }, + 'uri': 'gs://some-bucket/input_artifact_one' + }, + 'artifact_cls': artifact_types.ClassificationMetrics, + 'expected_type': artifact_types.ClassificationMetrics, + }, + { + 'runtime_artifact': { + 'metadata': {}, + 'name': 'input_artifact_one', + 'type': { + 'schemaTitle': 'system.SlicedClassificationMetrics' + }, + 'uri': 'gs://some-bucket/input_artifact_one' + }, + 'artifact_cls': artifact_types.SlicedClassificationMetrics, + 'expected_type': artifact_types.SlicedClassificationMetrics, + }, + { + 'runtime_artifact': { + 'metadata': {}, + 'name': 'input_artifact_one', + 'type': { + 'schemaTitle': 'system.HTML' + }, + 'uri': 'gs://some-bucket/input_artifact_one' + }, + 'artifact_cls': None, + 'expected_type': artifact_types.HTML, + }, + { + 'runtime_artifact': { + 'metadata': {}, + 'name': 'input_artifact_one', + 'type': { + 'schemaTitle': 'system.Markdown' + }, + 'uri': 'gs://some-bucket/input_artifact_one' + }, + 'artifact_cls': None, + 'expected_type': artifact_types.Markdown, + }, + ) + def test_dict_to_artifact_kfp_artifact( + self, + runtime_artifact, + artifact_cls, + expected_type, + ): + # with artifact_cls + self.assertIsInstance( + executor.create_artifact_instance( + runtime_artifact, artifact_cls=artifact_cls), expected_type) + + # without artifact_cls + self.assertIsInstance( + executor.create_artifact_instance(runtime_artifact), expected_type) + + def test_dict_to_artifact_nonkfp_artifact(self): + runtime_artifact = { + 'metadata': {}, + 'name': 'input_artifact_one', + 'type': { + 'schemaTitle': 'google.VertexDataset' + }, + 'uri': 'gs://some-bucket/input_artifact_one' + } + # with artifact_cls + self.assertIsInstance( + executor.create_artifact_instance( + runtime_artifact, artifact_cls=VertexDataset), VertexDataset) + + # without artifact_cls + self.assertIsInstance( + executor.create_artifact_instance(runtime_artifact), + artifact_types.Artifact) + + if __name__ == '__main__': unittest.main() diff --git a/sdk/python/kfp/components/structures.py b/sdk/python/kfp/components/structures.py index 7900df48cfa..6c5a81a5b79 100644 --- a/sdk/python/kfp/components/structures.py +++ b/sdk/python/kfp/components/structures.py @@ -629,7 +629,6 @@ def from_v1_component_spec( inputs = {} for spec in component_dict.get('inputs', []): type_ = spec.get('type') - print('TYPE', type_) if isinstance(type_, str) and type_ == 'PipelineTaskFinalStatus': inputs[utils.sanitize_input_name( diff --git a/sdk/python/kfp/components/types/artifact_types.py b/sdk/python/kfp/components/types/artifact_types.py index 7b783aa6598..4ab2a82afe7 100644 --- a/sdk/python/kfp/components/types/artifact_types.py +++ b/sdk/python/kfp/components/types/artifact_types.py @@ -510,21 +510,3 @@ def __init__(self, Markdown, ] } - - -def create_runtime_artifact(runtime_artifact: Dict) -> Artifact: - """Creates an Artifact instance from the specified RuntimeArtifact. - - Args: - runtime_artifact: Dictionary representing JSON-encoded RuntimeArtifact. - """ - schema_title = runtime_artifact.get('type', {}).get('schemaTitle', '') - - artifact_type = _SCHEMA_TITLE_TO_TYPE.get(schema_title) - if not artifact_type: - artifact_type = Artifact - return artifact_type( - uri=runtime_artifact.get('uri', ''), - name=runtime_artifact.get('name', ''), - metadata=runtime_artifact.get('metadata', {}), - ) diff --git a/sdk/python/kfp/components/types/artifact_types_test.py b/sdk/python/kfp/components/types/artifact_types_test.py index 8f906f66505..517d5f9b4ed 100644 --- a/sdk/python/kfp/components/types/artifact_types_test.py +++ b/sdk/python/kfp/components/types/artifact_types_test.py @@ -56,105 +56,6 @@ def test_complex_metrics_bulk_loading(self): expected_json = json.load(json_file) self.assertEqual(expected_json, metrics.metadata) - @parameterized.parameters( - { - 'runtime_artifact': { - 'metadata': {}, - 'name': 'input_artifact_one', - 'type': { - 'schemaTitle': 'system.Artifact' - }, - 'uri': 'gs://some-bucket/input_artifact_one' - }, - 'expected_type': artifact_types.Artifact, - }, - { - 'runtime_artifact': { - 'metadata': {}, - 'name': 'input_artifact_one', - 'type': { - 'schemaTitle': 'system.Model' - }, - 'uri': 'gs://some-bucket/input_artifact_one' - }, - 'expected_type': artifact_types.Model, - }, - { - 'runtime_artifact': { - 'metadata': {}, - 'name': 'input_artifact_one', - 'type': { - 'schemaTitle': 'system.Dataset' - }, - 'uri': 'gs://some-bucket/input_artifact_one' - }, - 'expected_type': artifact_types.Dataset, - }, - { - 'runtime_artifact': { - 'metadata': {}, - 'name': 'input_artifact_one', - 'type': { - 'schemaTitle': 'system.Metrics' - }, - 'uri': 'gs://some-bucket/input_artifact_one' - }, - 'expected_type': artifact_types.Metrics, - }, - { - 'runtime_artifact': { - 'metadata': {}, - 'name': 'input_artifact_one', - 'type': { - 'schemaTitle': 'system.ClassificationMetrics' - }, - 'uri': 'gs://some-bucket/input_artifact_one' - }, - 'expected_type': artifact_types.ClassificationMetrics, - }, - { - 'runtime_artifact': { - 'metadata': {}, - 'name': 'input_artifact_one', - 'type': { - 'schemaTitle': 'system.SlicedClassificationMetrics' - }, - 'uri': 'gs://some-bucket/input_artifact_one' - }, - 'expected_type': artifact_types.SlicedClassificationMetrics, - }, - { - 'runtime_artifact': { - 'metadata': {}, - 'name': 'input_artifact_one', - 'type': { - 'schemaTitle': 'system.HTML' - }, - 'uri': 'gs://some-bucket/input_artifact_one' - }, - 'expected_type': artifact_types.HTML, - }, - { - 'runtime_artifact': { - 'metadata': {}, - 'name': 'input_artifact_one', - 'type': { - 'schemaTitle': 'system.Markdown' - }, - 'uri': 'gs://some-bucket/input_artifact_one' - }, - 'expected_type': artifact_types.Markdown, - }, - ) - def test_create_runtime_artifact( - self, - runtime_artifact, - expected_type, - ): - self.assertIsInstance( - artifact_types.create_runtime_artifact(runtime_artifact), - expected_type) - if __name__ == '__main__': unittest.main() diff --git a/sdk/python/kfp/components/types/custom_artifact_types.py b/sdk/python/kfp/components/types/custom_artifact_types.py new file mode 100644 index 00000000000..2fba20b0c6d --- /dev/null +++ b/sdk/python/kfp/components/types/custom_artifact_types.py @@ -0,0 +1,191 @@ +# Copyright 2022 The Kubeflow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import inspect +from typing import Callable, Dict, List, Union + +from kfp.components import component_factory +from kfp.components.types import type_annotations +from kfp.components.types import type_utils + +RETURN_PREFIX = 'return-' + + +def get_custom_artifact_type_import_statements(func: Callable) -> List[str]: + """Gets a list of custom artifact type import statements from a lightweight + Python component function.""" + artifact_imports = get_custom_artifact_import_items_from_function(func) + imports_source = [] + for obj_str in artifact_imports: + if '.' in obj_str: + path, name = obj_str.rsplit('.', 1) + imports_source.append(f'from {path} import {name}') + else: + imports_source.append(f'import {obj_str}') + return imports_source + + +def get_param_to_custom_artifact_class(func: Callable) -> Dict[str, type]: + """Gets a map of parameter names to custom artifact classes. + + Return key is 'return-' for normal returns and 'return-' for + typing.NamedTuple returns. + """ + param_to_artifact_cls: Dict[str, type] = {} + kfp_artifact_classes = set(type_utils._ARTIFACT_CLASSES_MAPPING.values()) + + signature = inspect.signature(func) + for name, param in signature.parameters.items(): + annotation = param.annotation + if type_annotations.is_artifact_annotation(annotation): + artifact_class = type_annotations.get_io_artifact_class(annotation) + if artifact_class not in kfp_artifact_classes: + param_to_artifact_cls[name] = artifact_class + elif type_annotations.is_artifact_class(annotation): + param_to_artifact_cls[name] = annotation + if artifact_class not in kfp_artifact_classes: + param_to_artifact_cls[name] = artifact_class + + return_annotation = signature.return_annotation + + if return_annotation is inspect.Signature.empty: + pass + + elif type_utils.is_typed_named_tuple_annotation(return_annotation): + for name, annotation in return_annotation.__annotations__.items(): + if type_annotations.is_artifact_class( + annotation) and annotation not in kfp_artifact_classes: + param_to_artifact_cls[f'{RETURN_PREFIX}{name}'] = annotation + + elif type_annotations.is_artifact_class( + return_annotation + ) and return_annotation not in kfp_artifact_classes: + param_to_artifact_cls[RETURN_PREFIX] = return_annotation + + return param_to_artifact_cls + + +def get_full_qualname_for_artifact(obj: type) -> str: + """Gets the fully qualified name for an object. For example, for class Foo + in module bar.baz, this function returns bar.baz.Foo. + + Note: typing.get_type_hints purports to do the same thing, but it behaves + differently when executed within the scope of a test, so preferring this + approach instead. + + Args: + obj: The class or module for which to get the fully qualified name. + + Returns: + The fully qualified name for the class. + """ + module = obj.__module__ + name = obj.__qualname__ + if module is not None: + name = module + '.' + name + return name + + +def get_symbol_import_path(artifact_class_base_symbol: str, + qualname: str) -> str: + """Gets the fully qualified name of the symbol that must be imported for + the custom artifact type annotation to be referenced successfully. + + Args: + artifact_class_base_symbol: The base symbol from which the artifact class is referenced (e.g., aiplatform for aiplatform.VertexDataset). + qualname: The fully qualified type annotation name as a string. + + Returns: + The fully qualified names of the module or type to import. + """ + split_qualname = qualname.split('.') + if artifact_class_base_symbol in split_qualname: + name_to_import = '.'.join( + split_qualname[:split_qualname.index(artifact_class_base_symbol) + + 1]) + else: + raise TypeError( + f"Module or type name aliases are not supported. You appear to be using an alias in your type annotation: '{qualname}'. This may be due to use of an 'as' statement in an import statement or a reassignment of a module or type to a new name. Reference the module and/or type using the name as defined in the source from which the module or type is imported." + ) + return name_to_import + + +def traverse_ast_node_values_to_get_id(obj: Union[ast.Slice, None]) -> str: + while not hasattr(obj, 'id'): + obj = getattr(obj, 'value') + return obj.id + + +def get_custom_artifact_base_symbol_for_parameter(func: Callable, + arg_name: str) -> str: + """Gets the symbol required for the custom artifact type annotation to be + referenced correctly.""" + module_node = ast.parse( + component_factory._get_function_source_definition(func)) + args = module_node.body[0].args.args + args = {arg.arg: arg for arg in args} + annotation = args[arg_name].annotation + return traverse_ast_node_values_to_get_id(annotation.slice) + + +def get_custom_artifact_base_symbol_for_return(func: Callable, + return_name: str) -> str: + """Gets the symbol required for the custom artifact type return annotation + to be referenced correctly.""" + module_node = ast.parse( + component_factory._get_function_source_definition(func)) + return_ann = module_node.body[0].returns + + if return_name == RETURN_PREFIX: + if isinstance(return_ann, (ast.Name, ast.Attribute)): + return traverse_ast_node_values_to_get_id(return_ann) + elif isinstance(return_ann, ast.Call): + func = return_ann.func + # handles NamedTuple and typing.NamedTuple + if (isinstance(func, ast.Attribute) and func.value.id == 'typing' and + func.attr == 'NamedTuple') or (isinstance(func, ast.Name) and + func.id == 'NamedTuple'): + nt_field_list = return_ann.args[1].elts + for el in nt_field_list: + if f'{RETURN_PREFIX}{el.elts[0].s}' == return_name: + return traverse_ast_node_values_to_get_id(el.elts[1]) + + raise TypeError(f"Unexpected type annotation '{return_ann}' for {func}.") + + +def get_custom_artifact_import_items_from_function(func: Callable) -> List[str]: + """Gets the fully qualified name of the symbol that must be imported for + the custom artifact type annotation to be referenced successfully from a + component function.""" + + param_to_ann_obj = get_param_to_custom_artifact_class(func) + import_items = [] + for param_name, artifact_class in param_to_ann_obj.items(): + + base_symbol = get_custom_artifact_base_symbol_for_return( + func, param_name + ) if param_name.startswith( + RETURN_PREFIX) else get_custom_artifact_base_symbol_for_parameter( + func, param_name) + artifact_qualname = get_full_qualname_for_artifact(artifact_class) + symbol_import_path = get_symbol_import_path(base_symbol, + artifact_qualname) + + # could use set here, but want to be have deterministic import ordering + # in compilation + if symbol_import_path not in import_items: + import_items.append(symbol_import_path) + + return import_items diff --git a/sdk/python/kfp/components/types/custom_artifact_types_test.py b/sdk/python/kfp/components/types/custom_artifact_types_test.py new file mode 100644 index 00000000000..c4de629df45 --- /dev/null +++ b/sdk/python/kfp/components/types/custom_artifact_types_test.py @@ -0,0 +1,527 @@ +# Copyright 2022 The Kubeflow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +import sys +import tempfile +import textwrap +import typing +from typing import Any +import unittest + +from absl.testing import parameterized +import kfp +from kfp import dsl +from kfp.components.types import artifact_types +from kfp.components.types import custom_artifact_types +from kfp.components.types.artifact_types import Artifact +from kfp.components.types.artifact_types import Dataset +from kfp.components.types.type_annotations import Input +from kfp.components.types.type_annotations import InputPath +from kfp.components.types.type_annotations import Output +from kfp.components.types.type_annotations import OutputPath + +Alias = Artifact +artifact_types_alias = artifact_types + + +class _TestCaseWithThirdPartyPackage(parameterized.TestCase): + + @classmethod + def setUpClass(cls): + + class VertexDataset: + schema_title = 'google.VertexDataset' + schema_version = '0.0.0' + + class_source = textwrap.dedent(inspect.getsource(VertexDataset)) + + tmp_dir = tempfile.TemporaryDirectory() + with open(os.path.join(tmp_dir.name, 'aiplatform.py'), 'w') as f: + f.write(class_source) + sys.path.append(tmp_dir.name) + cls.tmp_dir = tmp_dir + + @classmethod + def teardownClass(cls): + sys.path.pop() + cls.tmp_dir.cleanup() + + +class TestGetParamToCustomArtifactClass(_TestCaseWithThirdPartyPackage): + + def test_no_ann(self): + + def func(): + pass + + actual = custom_artifact_types.get_param_to_custom_artifact_class(func) + self.assertEqual(actual, {}) + + def test_primitives(self): + + def func(a: str, b: int) -> str: + pass + + actual = custom_artifact_types.get_param_to_custom_artifact_class(func) + self.assertEqual(actual, {}) + + def test_input_path(self): + + def func(a: InputPath(str), b: InputPath('Dataset')) -> str: + pass + + actual = custom_artifact_types.get_param_to_custom_artifact_class(func) + self.assertEqual(actual, {}) + + def test_output_path(self): + + def func(a: OutputPath(str), b: OutputPath('Dataset')) -> str: + pass + + actual = custom_artifact_types.get_param_to_custom_artifact_class(func) + self.assertEqual(actual, {}) + + def test_input_kfp_artifact(self): + + def func(a: Input[Artifact]): + pass + + actual = custom_artifact_types.get_param_to_custom_artifact_class(func) + self.assertEqual(actual, {}) + + def test_output_kfp_artifact(self): + + def func(a: Output[Artifact]): + pass + + actual = custom_artifact_types.get_param_to_custom_artifact_class(func) + self.assertEqual(actual, {}) + + def test_return_kfp_artifact1(self): + + def func() -> Artifact: + pass + + actual = custom_artifact_types.get_param_to_custom_artifact_class(func) + self.assertEqual(actual, {}) + + def test_return_kfp_artifact2(self): + + def func() -> dsl.Artifact: + pass + + actual = custom_artifact_types.get_param_to_custom_artifact_class(func) + self.assertEqual(actual, {}) + + def test_named_tuple_primitives(self): + + def func() -> typing.NamedTuple('Outputs', [ + ('a', str), + ('b', int), + ]): + pass + + actual = custom_artifact_types.get_param_to_custom_artifact_class(func) + self.assertEqual(actual, {}) + + def test_input_google_artifact(self): + import aiplatform + from aiplatform import VertexDataset + + def func( + a: Input[aiplatform.VertexDataset], + b: Input[VertexDataset], + c: dsl.Input[aiplatform.VertexDataset], + d: kfp.dsl.Input[VertexDataset], + ): + pass + + actual = custom_artifact_types.get_param_to_custom_artifact_class(func) + self.assertEqual( + actual, { + 'a': aiplatform.VertexDataset, + 'b': aiplatform.VertexDataset, + 'c': aiplatform.VertexDataset, + 'd': aiplatform.VertexDataset, + }) + + def test_output_google_artifact(self): + import aiplatform + from aiplatform import VertexDataset + + def func( + a: Output[aiplatform.VertexDataset], + b: Output[VertexDataset], + c: dsl.Output[aiplatform.VertexDataset], + d: kfp.dsl.Output[VertexDataset], + ): + pass + + actual = custom_artifact_types.get_param_to_custom_artifact_class(func) + self.assertEqual( + actual, { + 'a': aiplatform.VertexDataset, + 'b': aiplatform.VertexDataset, + 'c': aiplatform.VertexDataset, + 'd': aiplatform.VertexDataset, + }) + + def test_return_google_artifact1(self): + import aiplatform + from aiplatform import VertexDataset + + def func() -> VertexDataset: + pass + + actual = custom_artifact_types.get_param_to_custom_artifact_class(func) + self.assertEqual(actual, {'return-': aiplatform.VertexDataset}) + + def test_return_google_artifact2(self): + import aiplatform + + def func() -> aiplatform.VertexDataset: + pass + + actual = custom_artifact_types.get_param_to_custom_artifact_class(func) + self.assertEqual(actual, {'return-': aiplatform.VertexDataset}) + + def test_named_tuple_google_artifact(self): + import aiplatform + from aiplatform import VertexDataset + + def func() -> typing.NamedTuple('Outputs', [ + ('a', aiplatform.VertexDataset), + ('b', VertexDataset), + ]): + pass + + actual = custom_artifact_types.get_param_to_custom_artifact_class(func) + self.assertEqual( + actual, { + 'return-a': aiplatform.VertexDataset, + 'return-b': aiplatform.VertexDataset, + }) + + +class TestGetFullQualnameForArtifact(_TestCaseWithThirdPartyPackage): + # only gets called on artifacts, so don't need to test on all types + @parameterized.parameters([ + (Alias, 'kfp.components.types.artifact_types.Artifact'), + (Artifact, 'kfp.components.types.artifact_types.Artifact'), + (Dataset, 'kfp.components.types.artifact_types.Dataset'), + ]) + def test(self, obj: Any, expected_qualname: str): + self.assertEqual( + custom_artifact_types.get_full_qualname_for_artifact(obj), + expected_qualname) + + def test_aiplatform_artifact(self): + import aiplatform + self.assertEqual( + custom_artifact_types.get_full_qualname_for_artifact( + aiplatform.VertexDataset), 'aiplatform.VertexDataset') + + +class TestGetSymbolImportPath(parameterized.TestCase): + + @parameterized.parameters([ + { + 'artifact_class_base_symbol': 'aiplatform', + 'qualname': 'aiplatform.VertexDataset', + 'expected': 'aiplatform' + }, + { + 'artifact_class_base_symbol': 'VertexDataset', + 'qualname': 'aiplatform.VertexDataset', + 'expected': 'aiplatform.VertexDataset' + }, + { + 'artifact_class_base_symbol': 'e', + 'qualname': 'a.b.c.d.e', + 'expected': 'a.b.c.d.e' + }, + { + 'artifact_class_base_symbol': 'c', + 'qualname': 'a.b.c.d.e', + 'expected': 'a.b.c' + }, + ]) + def test(self, artifact_class_base_symbol: str, qualname: str, + expected: str): + actual = custom_artifact_types.get_symbol_import_path( + artifact_class_base_symbol, qualname) + self.assertEqual(actual, expected) + + +class TestGetCustomArtifactBaseSymbolForParameter(_TestCaseWithThirdPartyPackage + ): + + def test_input_google_artifact(self): + import aiplatform + from aiplatform import VertexDataset + + def func( + a: Input[aiplatform.VertexDataset], + b: Input[VertexDataset], + c: dsl.Input[aiplatform.VertexDataset], + d: kfp.dsl.Input[VertexDataset], + ): + pass + + actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter( + func, 'a') + self.assertEqual(actual, 'aiplatform') + + actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter( + func, 'b') + self.assertEqual(actual, 'VertexDataset') + + actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter( + func, 'c') + self.assertEqual(actual, 'aiplatform') + + actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter( + func, 'd') + self.assertEqual(actual, 'VertexDataset') + + def test_output_google_artifact(self): + import aiplatform + from aiplatform import VertexDataset + + def func( + a: Output[aiplatform.VertexDataset], + b: Output[VertexDataset], + c: dsl.Output[aiplatform.VertexDataset], + d: kfp.dsl.Output[VertexDataset], + ): + pass + + actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter( + func, 'a') + self.assertEqual(actual, 'aiplatform') + + actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter( + func, 'b') + self.assertEqual(actual, 'VertexDataset') + + actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter( + func, 'c') + self.assertEqual(actual, 'aiplatform') + + actual = custom_artifact_types.get_custom_artifact_base_symbol_for_parameter( + func, 'd') + self.assertEqual(actual, 'VertexDataset') + + +class TestGetCustomArtifactBaseSymbolForReturn(_TestCaseWithThirdPartyPackage): + + def test_return_google_artifact1(self): + from aiplatform import VertexDataset + + def func() -> VertexDataset: + pass + + actual = custom_artifact_types.get_custom_artifact_base_symbol_for_return( + func, 'return-') + self.assertEqual(actual, 'VertexDataset') + + def test_return_google_artifact2(self): + import aiplatform + + def func() -> aiplatform.VertexDataset: + pass + + actual = custom_artifact_types.get_custom_artifact_base_symbol_for_return( + func, 'return-') + self.assertEqual(actual, 'aiplatform') + + def test_named_tuple_google_artifact(self): + import aiplatform + from aiplatform import VertexDataset + + def func() -> typing.NamedTuple('Outputs', [ + ('a', aiplatform.VertexDataset), + ('b', VertexDataset), + ]): + pass + + actual = custom_artifact_types.get_custom_artifact_base_symbol_for_return( + func, 'return-a') + self.assertEqual(actual, 'aiplatform') + + actual = custom_artifact_types.get_custom_artifact_base_symbol_for_return( + func, 'return-b') + self.assertEqual(actual, 'VertexDataset') + + +class TestGetCustomArtifactImportItemsFromFunction( + _TestCaseWithThirdPartyPackage): + + def test_no_ann(self): + + def func(): + pass + + actual = custom_artifact_types.get_custom_artifact_import_items_from_function( + func) + self.assertEqual(actual, []) + + def test_primitives(self): + + def func(a: str, b: int) -> str: + pass + + actual = custom_artifact_types.get_custom_artifact_import_items_from_function( + func) + self.assertEqual(actual, []) + + def test_input_path(self): + + def func(a: InputPath(str), b: InputPath('Dataset')) -> str: + pass + + actual = custom_artifact_types.get_custom_artifact_import_items_from_function( + func) + self.assertEqual(actual, []) + + def test_output_path(self): + + def func(a: OutputPath(str), b: OutputPath('Dataset')) -> str: + pass + + actual = custom_artifact_types.get_custom_artifact_import_items_from_function( + func) + self.assertEqual(actual, []) + + def test_input_kfp_artifact(self): + + def func(a: Input[Artifact]): + pass + + actual = custom_artifact_types.get_custom_artifact_import_items_from_function( + func) + self.assertEqual(actual, []) + + def test_output_kfp_artifact(self): + + def func(a: Output[Artifact]): + pass + + actual = custom_artifact_types.get_custom_artifact_import_items_from_function( + func) + self.assertEqual(actual, []) + + def test_return_kfp_artifact1(self): + + def func() -> Artifact: + pass + + actual = custom_artifact_types.get_custom_artifact_import_items_from_function( + func) + self.assertEqual(actual, []) + + def test_return_kfp_artifact2(self): + + def func() -> dsl.Artifact: + pass + + actual = custom_artifact_types.get_custom_artifact_import_items_from_function( + func) + self.assertEqual(actual, []) + + def test_named_tuple_primitives(self): + + def func() -> typing.NamedTuple('Outputs', [ + ('a', str), + ('b', int), + ]): + pass + + actual = custom_artifact_types.get_custom_artifact_import_items_from_function( + func) + self.assertEqual(actual, []) + + def test_input_google_artifact(self): + import aiplatform + from aiplatform import VertexDataset + + def func( + a: Input[aiplatform.VertexDataset], + b: Input[VertexDataset], + c: dsl.Input[aiplatform.VertexDataset], + d: kfp.dsl.Input[VertexDataset], + ): + pass + + actual = custom_artifact_types.get_custom_artifact_import_items_from_function( + func) + self.assertEqual(actual, ['aiplatform', 'aiplatform.VertexDataset']) + + def test_output_google_artifact(self): + import aiplatform + from aiplatform import VertexDataset + + def func( + a: Output[aiplatform.VertexDataset], + b: Output[VertexDataset], + c: dsl.Output[aiplatform.VertexDataset], + d: kfp.dsl.Output[VertexDataset], + ): + pass + + actual = custom_artifact_types.get_custom_artifact_import_items_from_function( + func) + + self.assertEqual(actual, ['aiplatform', 'aiplatform.VertexDataset']) + + def test_return_google_artifact1(self): + import aiplatform + from aiplatform import VertexDataset + + def func() -> VertexDataset: + pass + + actual = custom_artifact_types.get_custom_artifact_import_items_from_function( + func) + self.assertEqual(actual, ['aiplatform.VertexDataset']) + + def test_return_google_artifact2(self): + import aiplatform + + def func() -> aiplatform.VertexDataset: + pass + + actual = custom_artifact_types.get_custom_artifact_import_items_from_function( + func) + self.assertEqual(actual, ['aiplatform']) + + def test_named_tuple_google_artifact(self): + import aiplatform + from aiplatform import VertexDataset + + def func() -> typing.NamedTuple('Outputs', [ + ('a', aiplatform.VertexDataset), + ('b', VertexDataset), + ]): + pass + + actual = custom_artifact_types.get_custom_artifact_import_items_from_function( + func) + self.assertEqual(actual, ['aiplatform', 'aiplatform.VertexDataset']) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdk/python/kfp/components/types/type_annotations.py b/sdk/python/kfp/components/types/type_annotations.py index 95e54a6a150..74ed3170786 100644 --- a/sdk/python/kfp/components/types/type_annotations.py +++ b/sdk/python/kfp/components/types/type_annotations.py @@ -106,7 +106,7 @@ def __eq__(self, other): def construct_type_for_inputpath_or_outputpath( type_: Union[str, Type, None]) -> Union[str, None]: - if type_annotations.is_artifact(type_): + if type_annotations.is_artifact_class(type_): return type_utils.create_bundled_artifact_type(type_.schema_title, type_.schema_version) elif isinstance( @@ -274,7 +274,7 @@ def get_short_type_name(type_name: str) -> str: return type_name -def is_artifact(artifact_class_or_instance: Type) -> bool: +def is_artifact_class(artifact_class_or_instance: Type) -> bool: # we do not yet support non-pre-registered custom artifact types with instance_schema attribute return hasattr(artifact_class_or_instance, 'schema_title') and hasattr( artifact_class_or_instance, 'schema_version') diff --git a/sdk/python/kfp/components/types/type_annotations_test.py b/sdk/python/kfp/components/types/type_annotations_test.py index 645ed78274e..abd5b1680e3 100644 --- a/sdk/python/kfp/components/types/type_annotations_test.py +++ b/sdk/python/kfp/components/types/type_annotations_test.py @@ -161,31 +161,31 @@ class TestIsArtifact(parameterized.TestCase): 'obj': obj } for obj in artifact_types._SCHEMA_TITLE_TO_TYPE.values()]) def test_true_class(self, obj): - self.assertTrue(type_annotations.is_artifact(obj)) + self.assertTrue(type_annotations.is_artifact_class(obj)) @parameterized.parameters([{ 'obj': obj(name='name', uri='uri', metadata={}) } for obj in artifact_types._SCHEMA_TITLE_TO_TYPE.values()]) def test_true_instance(self, obj): - self.assertTrue(type_annotations.is_artifact(obj)) + self.assertTrue(type_annotations.is_artifact_class(obj)) @parameterized.parameters([{'obj': 'string'}, {'obj': 1}, {'obj': int}]) def test_false(self, obj): - self.assertFalse(type_annotations.is_artifact(obj)) + self.assertFalse(type_annotations.is_artifact_class(obj)) def test_false_no_schema_title(self): class NotArtifact: schema_version = '' - self.assertFalse(type_annotations.is_artifact(NotArtifact)) + self.assertFalse(type_annotations.is_artifact_class(NotArtifact)) def test_false_no_schema_version(self): class NotArtifact: schema_title = '' - self.assertFalse(type_annotations.is_artifact(NotArtifact)) + self.assertFalse(type_annotations.is_artifact_class(NotArtifact)) if __name__ == '__main__': diff --git a/sdk/python/kfp/components/types/type_utils.py b/sdk/python/kfp/components/types/type_utils.py index 1fac3359d1a..d37246bee87 100644 --- a/sdk/python/kfp/components/types/type_utils.py +++ b/sdk/python/kfp/components/types/type_utils.py @@ -28,6 +28,7 @@ # ComponentSpec I/O types to DSL ontology artifact classes mapping. _ARTIFACT_CLASSES_MAPPING = { + 'artifact': artifact_types.Artifact, 'model': artifact_types.Model, 'dataset': artifact_types.Dataset, 'metrics': artifact_types.Metrics, @@ -413,7 +414,7 @@ def _annotation_to_type_struct(annotation): type_struct = get_canonical_type_name_for_type(annotation) if type_struct: return type_struct - elif type_annotations.is_artifact(annotation): + elif type_annotations.is_artifact_class(annotation): schema_title = annotation.schema_title else: schema_title = str(annotation.__name__) @@ -423,3 +424,8 @@ def _annotation_to_type_struct(annotation): schema_title = str(annotation) type_struct = get_canonical_type_name_for_type(schema_title) return type_struct or schema_title + + +def is_typed_named_tuple_annotation(annotation: Any) -> bool: + return hasattr(annotation, '_fields') and hasattr(annotation, + '__annotations__') diff --git a/sdk/python/requirements.txt b/sdk/python/requirements.txt index 7478c5f6703..3820e811a9c 100644 --- a/sdk/python/requirements.txt +++ b/sdk/python/requirements.txt @@ -1,8 +1,8 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.7 # To update, run: # -# pip-compile +# pip-compile --no-emit-index-url requirements.in # absl-py==1.2.0 # via -r requirements.in @@ -34,7 +34,7 @@ google-api-core==2.8.2 # -r requirements.in # google-cloud-core # google-cloud-storage -google-auth==2.10.0 +google-auth==2.11.0 # via # -r requirements.in # google-api-core @@ -45,7 +45,7 @@ google-cloud-core==2.3.2 # via google-cloud-storage google-cloud-storage==2.5.0 # via -r requirements.in -google-crc32c==1.3.0 +google-crc32c==1.5.0 # via google-resumable-media google-resumable-media==2.3.3 # via google-cloud-storage @@ -53,11 +53,15 @@ googleapis-common-protos==1.56.4 # via google-api-core idna==3.3 # via requests +importlib-metadata==4.12.0 + # via + # click + # jsonschema jsonschema==3.2.0 # via -r requirements.in kfp-pipeline-spec==0.1.16 # via -r requirements.in -kfp-server-api==2.0.0a3 +kfp-server-api==2.0.0a4 # via -r requirements.in kubernetes==23.6.0 # via -r requirements.in @@ -114,19 +118,25 @@ termcolor==1.1.0 # via fire typer==0.6.1 # via -r requirements.in +typing-extensions==4.3.0 ; python_version < "3.9" + # via + # -r requirements.in + # importlib-metadata uritemplate==3.0.1 # via -r requirements.in -urllib3==1.26.11 +urllib3==1.26.12 # via # kfp-server-api # kubernetes # requests -websocket-client==1.3.3 +websocket-client==1.4.0 # via kubernetes wheel==0.37.1 # via strip-hints wrapt==1.14.1 # via deprecated +zipp==3.8.1 + # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/sdk/python/test_data/pipelines/pipeline_with_google_artifact_type.yaml b/sdk/python/test_data/pipelines/pipeline_with_google_artifact_type.yaml index 25f40272820..a45bb465a21 100644 --- a/sdk/python/test_data/pipelines/pipeline_with_google_artifact_type.yaml +++ b/sdk/python/test_data/pipelines/pipeline_with_google_artifact_type.yaml @@ -66,8 +66,9 @@ deploymentSpec: ' - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ - \ *\n\ndef model_consumer(model: Input[VertexModel],\n \ - \ dataset: Input[VertexDataset]):\n print('Model')\n print('artifact.type:\ + \ *\nfrom aiplatform import VertexModel\nfrom aiplatform import VertexDataset\n\ + \ndef model_consumer(model: Input[VertexModel],\n dataset:\ + \ Input[VertexDataset]):\n print('Model')\n print('artifact.type:\ \ ', type(model))\n print('artifact.name: ', model.name)\n print('artifact.uri:\ \ ', model.uri)\n print('artifact.metadata: ', model.metadata)\n\n \ \ print('Dataset')\n print('artifact.type: ', type(dataset))\n print('artifact.name:\ @@ -98,9 +99,9 @@ deploymentSpec: ' - "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ - \ *\n\ndef model_producer(model: Output[aiplatform.VertexModel]):\n\n \ - \ assert isinstance(model, aiplatform.VertexModel), type(model)\n with\ - \ open(model.path, 'w') as f:\n f.write('my model')\n\n" + \ *\nimport aiplatform\n\ndef model_producer(model: Output[aiplatform.VertexModel]):\n\ + \n assert isinstance(model, aiplatform.VertexModel), type(model)\n \ + \ with open(model.path, 'w') as f:\n f.write('my model')\n\n" image: python:3.7 pipelineInfo: name: pipeline-with-google-types