Skip to content

Commit

Permalink
chore(launcher): move artifact metadata to metadata struct field. Fixes #5788 (#5793)
Browse files Browse the repository at this point in the history

* feat(launcher): move component output metadata to metadata field

* add MLMD validation for metrics_visualization_v2

* clean up

* address comment

* fix
  • Loading branch information
Bobgy authored Jun 8, 2021
1 parent 941879d commit a75a3c7
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 63 deletions.
175 changes: 170 additions & 5 deletions samples/test/metrics_visualization_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,180 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import unittest.mock as mock
from pprint import pprint
import kfp
import kfp_server_api

from .metrics_visualization_v2 import metrics_visualization_pipeline
from .util import run_pipeline_func, TestCase
from .util import run_pipeline_func, TestCase, KfpMlmdClient


# Verification callback for the metrics_visualization_v2 sample.
#
# Asserts that the run succeeded, fetches the tasks recorded for the given
# Argo workflow through KfpMlmdClient, and compares the dictionary form of
# each task (task.get_dict()) against an expected literal.
#
# Parameters:
#   run: the kfp_server_api.ApiRun whose status must be 'Succeeded'.
#   mlmd_connection_config: passed straight through to KfpMlmdClient.
#   argo_workflow_name: workflow name used to look up the recorded tasks.
#   **kwargs: accepted so extra keyword arguments do not fail the call; unused.
def verify(
run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str,
**kwargs
):
# standalone TestCase instance, used only for its assert* helpers
t = unittest.TestCase()
t.maxDiff = None # we always want to see full diff
t.assertEqual(run.status, 'Succeeded')
client = KfpMlmdClient(mlmd_connection_config=mlmd_connection_config)
tasks = client.get_tasks(argo_workflow_name=argo_workflow_name)

# the pipeline must have produced exactly these three tasks (order-insensitive)
task_names = [*tasks.keys()]
t.assertCountEqual(
task_names,
['wine-classification', 'iris-sgdclassifier', 'digit-classification'],
'task names'
)

# NOTE(review): KfpTask is not among the names imported from .util in the
# visible import lines. Local variable annotations are not evaluated at
# runtime, but confirm the import for type checkers and linters.
wine_classification: KfpTask = tasks.get('wine-classification')
iris_sgdclassifier: KfpTask = tasks.get('iris-sgdclassifier')
digit_classification: KfpTask = tasks.get('digit-classification')

# print every task before asserting so the full structure shows up in the output
pprint('======= wine classification task =======')
pprint(wine_classification.get_dict())
pprint('======= iris sgdclassifier task =======')
pprint(iris_sgdclassifier.get_dict())
pprint('======= digit classification task =======')
pprint(digit_classification.get_dict())

# wine-classification: the 'metrics' artifact metadata holds the exact list of
# confidenceMetrics entries expected below
t.assertEqual(
wine_classification.get_dict(), {
'inputs': {
'artifacts': [],
'parameters': {}
},
'name': 'wine-classification',
'outputs': {
'artifacts': [{
'metadata': {
'confidenceMetrics': [{
'confidenceThreshold': 2.0,
'falsePositiveRate': 0.0,
'recall': 0.0
}, {
'confidenceThreshold': 1.0,
'falsePositiveRate': 0.0,
'recall': 0.33962264150943394
}, {
'confidenceThreshold': 0.9,
'falsePositiveRate': 0.0,
'recall': 0.6037735849056604
}, {
'confidenceThreshold': 0.8,
'falsePositiveRate': 0.0,
'recall': 0.8490566037735849
}, {
'confidenceThreshold': 0.6,
'falsePositiveRate': 0.0,
'recall': 0.8867924528301887
}, {
'confidenceThreshold': 0.5,
'falsePositiveRate': 0.0125,
'recall': 0.9245283018867925
}, {
'confidenceThreshold': 0.4,
'falsePositiveRate': 0.075,
'recall': 0.9622641509433962
}, {
'confidenceThreshold': 0.3,
'falsePositiveRate': 0.0875,
'recall': 1.0
}, {
'confidenceThreshold': 0.2,
'falsePositiveRate': 0.2375,
'recall': 1.0
}, {
'confidenceThreshold': 0.1,
'falsePositiveRate': 0.475,
'recall': 1.0
}, {
'confidenceThreshold': 0.0,
'falsePositiveRate': 1.0,
'recall': 1.0
}]
},
'name': 'metrics',
'type': 'system.ClassificationMetrics'
}],
'parameters': {}
},
'type': 'kfp.ContainerExecution'
}
)
# iris-sgdclassifier: the confusion matrix cells are matched with mock.ANY here
# and type-checked separately further down
t.assertEqual(
iris_sgdclassifier.get_dict(), {
'inputs': {
'artifacts': [],
'parameters': {
'test_samples_fraction': 0.3
}
},
'name': 'iris-sgdclassifier',
'outputs': {
'artifacts': [{
'metadata': {
'confusionMatrix': {
'annotationSpecs': [{
'displayName': 'Setosa'
}, {
'displayName': 'Versicolour'
}, {
'displayName': 'Virginica'
}],
'rows': [{ # these numbers can be random during execution
'row': [mock.ANY, mock.ANY, mock.ANY]
}, {
'row': [mock.ANY, mock.ANY, mock.ANY]
}, {
'row': [mock.ANY, mock.ANY, mock.ANY]
}]
}
},
'name': 'metrics',
'type': 'system.ClassificationMetrics'
}],
'parameters': {}
},
'type': 'kfp.ContainerExecution'
}
)
# mock.ANY accepts any value, so additionally require each cell to be a float
rows = iris_sgdclassifier.get_dict(
)['outputs']['artifacts'][0]['metadata']['confusionMatrix']['rows']
for i, row in enumerate(rows):
for j, item in enumerate(row['row']):
t.assertIsInstance(
item, float,
f'value of confusion matrix row {i}, col {j} is not a number'
)

# NOTE(review): the next two lines look like the removed pre-change verify()
# that the diff view interleaves with the new body — confirm against the file.
def verify(run, run_id: str, **kwargs):
assert run.status == 'Succeeded'
# digit-classification: a single system.Metrics artifact carrying the accuracy
t.assertEqual(
digit_classification.get_dict(), {
'inputs': {
'artifacts': [],
'parameters': {}
},
'name': 'digit-classification',
'outputs': {
'artifacts': [{
'metadata': {
'accuracy': 92.0
},
'name': 'metrics',
'type': 'system.Metrics'
}],
'parameters': {}
},
'type': 'kfp.ContainerExecution'
}
)


# Run the metrics visualization sample in V2_COMPATIBLE mode and check the
# result with verify().
# NOTE(review): this span appears to interleave the removed one-line TestCase
# form with its reformatted multi-line replacement — confirm against the file.
run_pipeline_func([
TestCase(pipeline_func=metrics_visualization_pipeline, verify_func=verify, mode=kfp.dsl.PipelineExecutionMode.V2_COMPATIBLE),
])
TestCase(
pipeline_func=metrics_visualization_pipeline,
verify_func=verify,
mode=kfp.dsl.PipelineExecutionMode.V2_COMPATIBLE
),
])
67 changes: 37 additions & 30 deletions samples/test/two_step_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ def verify_tasks(t: unittest.TestCase, tasks: dict):
}
},
'outputs': {
'artifacts': [{'custom_properties': {'name': 'output_dataset_one'},
'name': 'output_dataset_one',
'type': 'system.Dataset'
}],
'artifacts': [{
'metadata': {},
'name': 'output_dataset_one',
'type': 'system.Dataset'
}],
'parameters': {
'output_parameter_one': 1234
}
Expand All @@ -71,19 +72,21 @@ def verify_tasks(t: unittest.TestCase, tasks: dict):
train.get_dict(), {
'name': 'train-op',
'inputs': {
'artifacts': [{'custom_properties': {'name': 'output_dataset_one'},
'name': 'output_dataset_one',
'type': 'system.Dataset',
}],
'artifacts': [{
'metadata': {},
'name': 'output_dataset_one',
'type': 'system.Dataset',
}],
'parameters': {
'num_steps': 1234
}
},
'outputs': {
'artifacts': [{'custom_properties': {'name': 'model'},
'name': 'model',
'type': 'system.Model',
}],
'artifacts': [{
'metadata': {},
'name': 'model',
'type': 'system.Model',
}],
'parameters': {}
},
'type': 'kfp.ContainerExecution'
Expand All @@ -98,18 +101,19 @@ def verify_artifacts(t: unittest.TestCase, tasks: dict, artifact_uri_prefix):


# Verification callback for the two-step sample: asserts the run succeeded,
# loads the recorded tasks via get_tasks() and delegates the detailed checks
# to verify_tasks().
# NOTE(review): the parameter lines appear twice below; this looks like the
# diff view showing the old and re-indented versions of the same lines.
def verify(
run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str,
**kwargs
run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str,
**kwargs
):
# standalone TestCase instance, used only for its assert* helpers
t = unittest.TestCase()
t.maxDiff = None # we always want to see full diff
t.assertEqual(run.status, 'Succeeded')
tasks = get_tasks(mlmd_connection_config, argo_workflow_name)
verify_tasks(t, tasks)


def verify_with_default_pipeline_root(
run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str,
**kwargs
run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str,
**kwargs
):
t = unittest.TestCase()
t.maxDiff = None # we always want to see full diff
Expand All @@ -120,8 +124,8 @@ def verify_with_default_pipeline_root(


def verify_with_specific_pipeline_root(
run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str,
**kwargs
run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str,
**kwargs
):
t = unittest.TestCase()
t.maxDiff = None # we always want to see full diff
Expand All @@ -143,19 +147,22 @@ def verify_with_specific_pipeline_root(
mode=kfp.dsl.PipelineExecutionMode.V1_LEGACY
),
# Verify default pipeline_root with MinIO
TestCase(pipeline_func=two_step_pipeline,
verify_func=verify_with_default_pipeline_root,
mode=kfp.dsl.PipelineExecutionMode.V2_COMPATIBLE,
arguments={
kfp.dsl.ROOT_PARAMETER_NAME: ''},
),
TestCase(
pipeline_func=two_step_pipeline,
verify_func=verify_with_default_pipeline_root,
mode=kfp.dsl.PipelineExecutionMode.V2_COMPATIBLE,
arguments={kfp.dsl.ROOT_PARAMETER_NAME: ''},
),
# Verify overriding pipeline root to MinIO
TestCase(pipeline_func=two_step_pipeline,
verify_func=verify_with_specific_pipeline_root,
mode=kfp.dsl.PipelineExecutionMode.V2_COMPATIBLE,
arguments={
kfp.dsl.ROOT_PARAMETER_NAME: 'minio://mlpipeline/override/artifacts' },
)
TestCase(
pipeline_func=two_step_pipeline,
verify_func=verify_with_specific_pipeline_root,
mode=kfp.dsl.PipelineExecutionMode.V2_COMPATIBLE,
arguments={
kfp.dsl.ROOT_PARAMETER_NAME:
'minio://mlpipeline/override/artifacts'
},
)
])

# %%
25 changes: 10 additions & 15 deletions samples/test/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from dataclasses import dataclass, asdict
from pprint import pprint
from typing import Dict, List, Callable, Optional
from google.protobuf.json_format import MessageToDict

import kfp
import kfp_server_api
Expand Down Expand Up @@ -222,34 +223,28 @@ class KfpArtifact:
name: str
uri: str
type: str
custom_properties: dict
metadata: dict

# Alternate constructor: builds a KfpArtifact from an MLMD artifact message
# and its artifact type message.
# NOTE(review): this span is a diff view. The custom_properties loop, the
# `!= ''` check and the `custom_properties=` keyword appear to be the removed
# version, interleaved with the new metadata-based lines — confirm against the
# actual file.
@classmethod
def new(
cls, mlmd_artifact: metadata_store_pb2.Artifact,
mlmd_artifact_type: metadata_store_pb2.ArtifactType
):
# (old) flatten every custom property to its string/int/double raw value
custom_properties = {}
for k, v in mlmd_artifact.custom_properties.items():
raw_value = None
if v.string_value:
raw_value = v.string_value
if v.int_value:
raw_value = v.int_value
if v.double_value:
raw_value = v.double_value
custom_properties[k] = raw_value
# prefer the artifact's own name; otherwise fall back to the 'name' custom property
artifact_name = ''
if mlmd_artifact.name != '':
if mlmd_artifact.name:
artifact_name = mlmd_artifact.name
else:
if 'name' in custom_properties.keys():
artifact_name = custom_properties['name']
if 'name' in mlmd_artifact.custom_properties.keys():
artifact_name = mlmd_artifact.custom_properties['name'
].string_value
# convert the 'metadata' custom property's struct value into a plain dict
# NOTE(review): if an artifact has no 'metadata' custom property, .get()
# presumably returns None and .struct_value would raise AttributeError —
# confirm every artifact read here carries that key.
metadata = MessageToDict(
mlmd_artifact.custom_properties.get('metadata').struct_value
)
return cls(
name=artifact_name,
type=mlmd_artifact_type.name,
uri=mlmd_artifact.uri,
custom_properties=custom_properties
metadata=metadata
)


Expand Down
Loading

0 comments on commit a75a3c7

Please sign in to comment.