From 8cb483918bdbaeae34935eef2b3cd997c1ae89a3 Mon Sep 17 00:00:00 2001 From: Vinny Senthil Date: Thu, 2 Dec 2021 16:40:10 -0800 Subject: [PATCH] fix: Support multiple instances in custom predict sample (#857) --- .../predict_custom_trained_model_sample.py | 14 ++++++++++---- .../predict_custom_trained_model_sample_test.py | 13 ++++++++++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/samples/snippets/prediction_service/predict_custom_trained_model_sample.py b/samples/snippets/prediction_service/predict_custom_trained_model_sample.py index d48085945b..5d04fc2400 100644 --- a/samples/snippets/prediction_service/predict_custom_trained_model_sample.py +++ b/samples/snippets/prediction_service/predict_custom_trained_model_sample.py @@ -13,7 +13,7 @@ # limitations under the License. # [START aiplatform_predict_custom_trained_model_sample] -from typing import Dict +from typing import Dict, List, Union from google.cloud import aiplatform from google.protobuf import json_format @@ -23,18 +23,24 @@ def predict_custom_trained_model_sample( project: str, endpoint_id: str, - instance_dict: Dict, + instances: Union[Dict, List[Dict]], location: str = "us-central1", api_endpoint: str = "us-central1-aiplatform.googleapis.com", ): + """ + `instances` can be either single instance of type dict or a list + of instances. + """ # The AI Platform services require regional API endpoints. client_options = {"api_endpoint": api_endpoint} # Initialize client that will be used to create and send requests. # This client only needs to be created once, and can be reused for multiple requests. client = aiplatform.gapic.PredictionServiceClient(client_options=client_options) # The format of each instance should conform to the deployed model's prediction input schema. - instance = json_format.ParseDict(instance_dict, Value()) - instances = [instance] + instances = instances if type(instances) == list else [instances] + instances = [ + json_format.ParseDict(instance_dict, Value()) for instance_dict in instances + ] parameters_dict = {} parameters = json_format.ParseDict(parameters_dict, Value()) endpoint = client.endpoint_path( diff --git a/samples/snippets/prediction_service/predict_custom_trained_model_sample_test.py b/samples/snippets/prediction_service/predict_custom_trained_model_sample_test.py index b3cc1ddaf6..9ba69047c5 100644 --- a/samples/snippets/prediction_service/predict_custom_trained_model_sample_test.py +++ b/samples/snippets/prediction_service/predict_custom_trained_model_sample_test.py @@ -32,9 +32,20 @@ def test_ucaip_generated_predict_custom_trained_model_sample(capsys): instance_dict = {"image_bytes": {"b64": encoded_content}, "key": "0"} + # Single instance as a dict predict_custom_trained_model_sample.predict_custom_trained_model_sample( - instance_dict=instance_dict, project=PROJECT_ID, endpoint_id=ENDPOINT_ID + instances=instance_dict, project=PROJECT_ID, endpoint_id=ENDPOINT_ID + ) + + # Multiple instances in a list + predict_custom_trained_model_sample.predict_custom_trained_model_sample( + instances=[instance_dict, instance_dict], + project=PROJECT_ID, + endpoint_id=ENDPOINT_ID, ) out, _ = capsys.readouterr() assert "1.0" in out + + # Two sets of scores for multi-instance, one score for single instance + assert out.count("scores") == 3