def _expected_task(name, artifact_type, metadata, input_parameters=None):
    """Build the dict expected from ``get_dict()`` for one metrics task.

    Every task checked in this file has the same envelope: no input
    artifacts, a single output artifact called ``metrics``, no output
    parameters and the ``kfp.ContainerExecution`` execution type.

    Args:
        name: Expected task name.
        artifact_type: Expected type name of the ``metrics`` artifact.
        metadata: Expected metadata dict of the ``metrics`` artifact.
        input_parameters: Expected input parameters; omitted means none.

    Returns:
        The expected task dict.
    """
    return {
        'inputs': {
            'artifacts': [],
            'parameters': dict(input_parameters or {}),
        },
        'name': name,
        'outputs': {
            'artifacts': [{
                'metadata': metadata,
                'name': 'metrics',
                'type': artifact_type,
            }],
            'parameters': {},
        },
        'type': 'kfp.ContainerExecution',
    }


def verify(
    run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str,
    **kwargs
):
    """Verify the tasks recorded for a metrics visualization pipeline run.

    Asserts that the run succeeded, that exactly the three expected tasks
    were recorded, and that each task reports the expected ``metrics``
    output artifact.

    Args:
        run: The finished run; its ``status`` must be ``'Succeeded'``.
        mlmd_connection_config: Passed to ``KfpMlmdClient`` to look up tasks.
        argo_workflow_name: Workflow name whose tasks are fetched.
        **kwargs: Accepted and ignored.

    Raises:
        AssertionError: If any of the checks fails.
    """
    t = unittest.TestCase()
    t.maxDiff = None  # we always want to see full diff
    t.assertEqual(run.status, 'Succeeded')
    client = KfpMlmdClient(mlmd_connection_config=mlmd_connection_config)
    tasks = client.get_tasks(argo_workflow_name=argo_workflow_name)

    expected_task_names = [
        'wine-classification', 'iris-sgdclassifier', 'digit-classification'
    ]
    t.assertCountEqual(list(tasks.keys()), expected_task_names, 'task names')

    # The previous local annotations named `KfpTask`, which this module never
    # imports; they are dropped so the function references no undefined name.
    # Each task object is only required to provide `get_dict()`.
    task_dicts = {
        name: tasks.get(name).get_dict() for name in expected_task_names
    }

    # Dump every task before asserting so a failing run still shows all three.
    for name in expected_task_names:
        pprint(f"======= {name.replace('-', ' ')} task =======")
        pprint(task_dicts[name])

    # (confidenceThreshold, falsePositiveRate, recall)
    wine_confidence_points = [
        (2.0, 0.0, 0.0),
        (1.0, 0.0, 0.33962264150943394),
        (0.9, 0.0, 0.6037735849056604),
        (0.8, 0.0, 0.8490566037735849),
        (0.6, 0.0, 0.8867924528301887),
        (0.5, 0.0125, 0.9245283018867925),
        (0.4, 0.075, 0.9622641509433962),
        (0.3, 0.0875, 1.0),
        (0.2, 0.2375, 1.0),
        (0.1, 0.475, 1.0),
        (0.0, 1.0, 1.0),
    ]
    t.assertEqual(
        task_dicts['wine-classification'],
        _expected_task(
            name='wine-classification',
            artifact_type='system.ClassificationMetrics',
            metadata={
                'confidenceMetrics': [{
                    'confidenceThreshold': threshold,
                    'falsePositiveRate': false_positive_rate,
                    'recall': recall,
                } for threshold, false_positive_rate, recall in
                                      wine_confidence_points]
            },
        )
    )

    iris_dict = task_dicts['iris-sgdclassifier']
    t.assertEqual(
        iris_dict,
        _expected_task(
            name='iris-sgdclassifier',
            artifact_type='system.ClassificationMetrics',
            metadata={
                'confusionMatrix': {
                    'annotationSpecs': [
                        {
                            'displayName': 'Setosa'
                        },
                        {
                            'displayName': 'Versicolour'
                        },
                        {
                            'displayName': 'Virginica'
                        },
                    ],
                    # these numbers can be random during execution
                    'rows': [{
                        'row': [mock.ANY, mock.ANY, mock.ANY]
                    } for _ in range(3)],
                }
            },
            input_parameters={'test_samples_fraction': 0.3},
        )
    )
    # mock.ANY above accepts anything, so check the cell types explicitly.
    rows = iris_dict['outputs']['artifacts'][0]['metadata']['confusionMatrix'][
        'rows']
    for i, row in enumerate(rows):
        for j, item in enumerate(row['row']):
            t.assertIsInstance(
                item, float,
                f'value of confusion matrix row {i}, col {j} is not a number'
            )

    t.assertEqual(
        task_dicts['digit-classification'],
        _expected_task(
            name='digit-classification',
            artifact_type='system.Metrics',
            metadata={'accuracy': 92.0},
        )
    )
-98,8 +101,8 @@ def verify_artifacts(t: unittest.TestCase, tasks: dict, artifact_uri_prefix): def verify( - run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str, - **kwargs + run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str, + **kwargs ): t = unittest.TestCase() t.maxDiff = None # we always want to see full diff @@ -107,9 +110,10 @@ def verify( tasks = get_tasks(mlmd_connection_config, argo_workflow_name) verify_tasks(t, tasks) + def verify_with_default_pipeline_root( - run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str, - **kwargs + run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str, + **kwargs ): t = unittest.TestCase() t.maxDiff = None # we always want to see full diff @@ -120,8 +124,8 @@ def verify_with_default_pipeline_root( def verify_with_specific_pipeline_root( - run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str, - **kwargs + run: kfp_server_api.ApiRun, mlmd_connection_config, argo_workflow_name: str, + **kwargs ): t = unittest.TestCase() t.maxDiff = None # we always want to see full diff @@ -143,19 +147,22 @@ def verify_with_specific_pipeline_root( mode=kfp.dsl.PipelineExecutionMode.V1_LEGACY ), # Verify default pipeline_root with MinIO - TestCase(pipeline_func=two_step_pipeline, - verify_func=verify_with_default_pipeline_root, - mode=kfp.dsl.PipelineExecutionMode.V2_COMPATIBLE, - arguments={ - kfp.dsl.ROOT_PARAMETER_NAME: ''}, - ), + TestCase( + pipeline_func=two_step_pipeline, + verify_func=verify_with_default_pipeline_root, + mode=kfp.dsl.PipelineExecutionMode.V2_COMPATIBLE, + arguments={kfp.dsl.ROOT_PARAMETER_NAME: ''}, + ), # Verify overriding pipeline root to MinIO - TestCase(pipeline_func=two_step_pipeline, - verify_func=verify_with_specific_pipeline_root, - mode=kfp.dsl.PipelineExecutionMode.V2_COMPATIBLE, - arguments={ - kfp.dsl.ROOT_PARAMETER_NAME: 'minio://mlpipeline/override/artifacts' }, - ) + TestCase( + 
@classmethod
def new(
    cls, mlmd_artifact: metadata_store_pb2.Artifact,
    mlmd_artifact_type: metadata_store_pb2.ArtifactType
):
    """Create a KfpArtifact from an MLMD artifact and its artifact type.

    The name comes from ``mlmd_artifact.name`` when it is set, otherwise
    from the string value of the ``name`` custom property, otherwise ``''``.
    The metadata is the ``metadata`` custom property's struct value converted
    to a plain dict, or ``{}`` when the artifact has no such property.

    Args:
        mlmd_artifact: Artifact whose name, uri and custom properties are read.
        mlmd_artifact_type: Artifact type whose ``name`` becomes ``type``.

    Returns:
        A new instance of ``cls``.
    """
    custom_properties = mlmd_artifact.custom_properties

    artifact_name = ''
    if mlmd_artifact.name:
        artifact_name = mlmd_artifact.name
    elif 'name' in custom_properties:
        artifact_name = custom_properties['name'].string_value

    # `.get('metadata')` yields None for artifacts without the property, so
    # test membership first instead of dereferencing the result blindly.
    metadata = {}
    if 'metadata' in custom_properties:
        metadata = MessageToDict(custom_properties['metadata'].struct_value)

    return cls(
        name=artifact_name,
        type=mlmd_artifact_type.name,
        uri=mlmd_artifact.uri,
        metadata=metadata
    )
pipelineSpecValueToMLMDValue(v *pipeline_spec.Value) (*pb.Value, error) { switch t := v.Value.(type) { case *pipeline_spec.Value_StringValue: - return &pb.Value{Value: &pb.Value_StringValue{StringValue: v.GetStringValue()}}, nil + return stringToMLMDValue(v.GetStringValue()), nil case *pipeline_spec.Value_DoubleValue: return &pb.Value{Value: &pb.Value_DoubleValue{DoubleValue: v.GetDoubleValue()}}, nil case *pipeline_spec.Value_IntValue: @@ -141,7 +141,7 @@ func structValueToMLMDValue(v *structpb.Value) (*pb.Value, error) { switch t := v.Kind.(type) { case *structpb.Value_StringValue: - return &pb.Value{Value: &pb.Value_StringValue{StringValue: v.GetStringValue()}}, nil + return stringToMLMDValue(v.GetStringValue()), nil case *structpb.Value_NumberValue: return &pb.Value{Value: &pb.Value_DoubleValue{DoubleValue: v.GetNumberValue()}}, nil case *structpb.Value_BoolValue: @@ -192,15 +192,10 @@ func toMLMDArtifact(runtimeArtifact *pipeline_spec.RuntimeArtifact) (*pb.Artifac artifact.CustomProperties[k] = value } - if runtimeArtifact.Metadata != nil { - for k, v := range runtimeArtifact.Metadata.Fields { - value, err := structValueToMLMDValue(v) - if err != nil { - return nil, errorF(err) - } - artifact.CustomProperties[k] = value - } - } + artifact.CustomProperties["name"] = stringToMLMDValue(runtimeArtifact.Name) + artifact.CustomProperties["metadata"] = &pb.Value{Value: &pb.Value_StructValue{ + StructValue: runtimeArtifact.Metadata, + }} return artifact, nil } @@ -354,9 +349,8 @@ func (r *runtimeInfo) generateExecutorInput(genOutputURI generateOutputURI, outp } if strings.HasPrefix(uri, "s3://") { s3Region := os.Getenv("AWS_REGION") - rta.Metadata.Fields["s3_region"] = &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: s3Region}} + rta.Metadata.Fields["s3_region"] = stringToStructValue(s3Region) } - rta.Metadata.Fields["name"] = &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: name}} if err := setRuntimeArtifactType(rta, oa.InstanceSchema, 
oa.SchemaTitle); err != nil { return nil, fmt.Errorf("failed to generate output RuntimeArtifact: %w", err) @@ -371,3 +365,11 @@ func (r *runtimeInfo) generateExecutorInput(genOutputURI generateOutputURI, outp Outputs: outputs, }, nil } + +func stringToStructValue(v string) *structpb.Value { + return &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: v}} +} + +func stringToMLMDValue(v string) *pb.Value { + return &pb.Value{Value: &pb.Value_StringValue{StringValue: v}} +}