Skip to content

Commit

Permalink
support artifact types under google namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
chensun committed Sep 30, 2021
1 parent ab58885 commit b385175
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@
"schemaTitle": "system.HTML",
"schemaVersion": "0.0.1"
}
},
"input_i": {
"artifactType": {
"schemaTitle": "google.BQMLModel",
"schemaVersion": "0.0.1"
}
}
},
"parameters": {
Expand Down Expand Up @@ -116,6 +122,12 @@
"schemaTitle": "system.HTML",
"schemaVersion": "0.0.1"
}
},
"output_9": {
"artifactType": {
"schemaTitle": "google.BQMLModel",
"schemaVersion": "0.0.1"
}
}
},
"parameters": {
Expand Down Expand Up @@ -223,6 +235,12 @@
"outputArtifactKey": "output_8",
"producerTask": "upstream"
}
},
"input_i": {
"taskOutputArtifact": {
"outputArtifactKey": "output_9",
"producerTask": "upstream"
}
}
},
"parameters": {
Expand Down Expand Up @@ -286,7 +304,7 @@
}
},
"schemaVersion": "2.0.0",
"sdkVersion": "kfp-1.7.2"
"sdkVersion": "kfp-1.8.3"
},
"runtimeConfig": {
"gcsOutputDirectory": "dummy_root",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
import pathlib

from kfp import components
from kfp.v2 import dsl
import kfp.v2.compiler as compiler
from kfp.v2 import compiler, dsl

component_op_1 = components.load_component_from_text("""
name: upstream
Expand All @@ -34,6 +33,7 @@
- {name: output_6, type: Some arbitrary type}
- {name: output_7, type: {GcsPath: {data_type: TSV}}}
- {name: output_8, type: HTML}
- {name: output_9, type: google.BQMLModel}
implementation:
container:
image: gcr.io/image
Expand Down Expand Up @@ -63,6 +63,7 @@
- {name: input_f, type: Some arbitrary type}
- {name: input_g, type: {GcsPath: {data_type: TSV}}}
- {name: input_h, type: HTML}
- {name: input_i, type: google.BQMLModel}
implementation:
container:
image: gcr.io/image
Expand Down Expand Up @@ -95,6 +96,7 @@ def my_pipeline(input1: str, input3: str, input4: str = ''):
input_f=component_1.outputs['output_6'],
input_g=component_1.outputs['output_7'],
input_h=component_1.outputs['output_8'],
input_i=component_1.outputs['output_9'],
)


Expand Down
14 changes: 11 additions & 3 deletions sdk/python/kfp/v2/components/types/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
# limitations under the License.
"""Utilities for component I/O type mapping."""
import inspect
import re
from typing import Dict, List, Optional, Type, Union
from kfp.components import structures
from kfp.components import type_annotation_utils

from kfp.components import structures, type_annotation_utils
from kfp.pipeline_spec import pipeline_spec_pb2
from kfp.v2.components.types import artifact_types


PARAMETER_TYPES = Union[str, int, float, bool, dict, list]

# ComponentSpec I/O types to DSL ontology artifact classes mapping.
Expand All @@ -33,6 +33,9 @@
'markdown': artifact_types.Markdown,
}

_GOOGLE_TYPES_PATTERN = r'^google.[A-Za-z]+$'
_GOOGLE_TYPES_VERSION = '0.0.1'

# ComponentSpec I/O types to (IR) PipelineTaskSpec I/O types mapping.
# The keys are normalized (lowercased). These are types viewed as Parameters.
# The values are the corresponding IR parameter primitive types.
Expand Down Expand Up @@ -88,6 +91,11 @@ def get_artifact_type_schema(
type."""
artifact_class = artifact_types.Artifact
if isinstance(artifact_class_or_type_name, str):
if re.match(_GOOGLE_TYPES_PATTERN, artifact_class_or_type_name):
return pipeline_spec_pb2.ArtifactTypeSchema(
schema_title=artifact_class_or_type_name,
schema_version=_GOOGLE_TYPES_VERSION,
)
artifact_class = _ARTIFACT_CLASSES_MAPPING.get(
artifact_class_or_type_name.lower(), artifact_types.Artifact)
elif inspect.isclass(artifact_class_or_type_name) and issubclass(
Expand Down
32 changes: 30 additions & 2 deletions sdk/python/kfp/v2/components/types/type_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized

import sys
import unittest
from typing import Any, Dict, List

from absl.testing import parameterized
from kfp.components import structures
from kfp.pipeline_spec import pipeline_spec_pb2 as pb
from kfp.v2.components.types import artifact_types, type_utils
Expand Down Expand Up @@ -49,6 +48,14 @@ class _ArbitraryClass:
pass


class _VertexDummy(artifact_types.Artifact):
TYPE_NAME = 'google.VertexDummy'
VERSION = '0.0.2'

def __init__(self):
super().__init__(uri='uri', name='name', metadata={'dummy': '123'})


class TypeUtilsTest(parameterized.TestCase):

def test_is_parameter_type(self):
Expand Down Expand Up @@ -160,6 +167,27 @@ def test_is_parameter_type(self):
pb.ArtifactTypeSchema(
schema_title='system.Markdown', schema_version='0.0.1')
},
{
'artifact_class_or_type_name':
'some-google-type',
'expected_result':
pb.ArtifactTypeSchema(
schema_title='system.Artifact', schema_version='0.0.1')
},
{
'artifact_class_or_type_name':
'google.VertexModel',
'expected_result':
pb.ArtifactTypeSchema(
schema_title='google.VertexModel', schema_version='0.0.1')
},
{
'artifact_class_or_type_name':
_VertexDummy,
'expected_result':
pb.ArtifactTypeSchema(
schema_title='google.VertexDummy', schema_version='0.0.2')
},
)
def test_get_artifact_type_schema(self, artifact_class_or_type_name,
expected_result):
Expand Down

0 comments on commit b385175

Please sign in to comment.