Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs(samples): add AutoML image classification sample #923

Merged
merged 22 commits into from
Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
8fc413d
Create predict_image_classification_sample.py
andrewferlitsch Dec 21, 2021
6100f76
feat: new sample and test
andrewferlitsch Dec 21, 2021
ed92732
lint: fix wsp
andrewferlitsch Dec 21, 2021
787b460
lint: import order
andrewferlitsch Dec 21, 2021
991b878
lint: fix import
andrewferlitsch Dec 21, 2021
5752cd9
tags: fixed start tag
andrewferlitsch Dec 21, 2021
dda6fb7
Merge branch 'main' into issue_202042079
andrewferlitsch Jan 8, 2022
37b1c78
samples: change tabular to image in sample function name.
andrewferlitsch Mar 17, 2022
e38c2b4
Merge branch 'main' into issue_202042079
andrewferlitsch Mar 17, 2022
102b4c5
samples: replace TF version of reading in binary file with Python ver…
andrewferlitsch Mar 17, 2022
bca85e9
samples: delete tf import, move other imports within region tags
andrewferlitsch Mar 17, 2022
b8cbfa1
Update predict_image_classification_sample.py
andrewferlitsch Mar 17, 2022
14428d4
samples: move imports for lint
andrewferlitsch Mar 17, 2022
9d6a9bb
Merge branch 'main' into issue_202042079
kweinmeister Mar 31, 2022
3a22b8a
Merge branch 'main' into issue_202042079
kweinmeister Apr 11, 2022
5ba43af
Merge branch 'main' into issue_202042079
kweinmeister Apr 22, 2022
291695e
Update predict_image_classification_sample.py
andrewferlitsch Jul 1, 2022
2c8186c
Update predict_image_classification_sample_test.py
andrewferlitsch Jul 1, 2022
a9fdbf9
Merge branch 'main' into issue_202042079
andrewferlitsch Jul 1, 2022
7656a4c
Merge branch 'main' into issue_202042079
rosiezou Aug 2, 2022
e97ca2b
Merge branch 'main' into issue_202042079
nayaknishant Aug 3, 2022
cceff6b
Merge branch 'main' into issue_202042079
rosiezou Aug 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions samples/model-builder/predict_image_classification_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import base64

from typing import List

from google.cloud import aiplatform

import tensorflow as tf


# [START aiplatform_sdk_predict_image_classification_sample]
def predict_tabular_classification_sample(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this read predict_image_classification_sample?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, that looks like a typo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed it.

project: str,
location: str,
endpoint_name: str,
images: List,
):
'''
Args
project: Your project ID or project number.
location: Region where Endpoint is located. For example, 'us-central1'.
endpoint_name: A fully qualified endpoint name or endpoint ID. Example: "projects/123/locations/us-central1/endpoints/456" or
"456" when project and location are initialized or passed.
images: A list of one or more images to return a prediction for.
'''
aiplatform.init(project=project, location=location)

endpoint = aiplatform.Endpoint(endpoint_name)

instances = []
for image in images:
with tf.io.gfile.GFile(image, "rb") as f:
content = f.read()
instances.append({"content": base64.b64encode(content).decode("utf-8")})

response = endpoint.predict(instances=instances)

for prediction_ in response.predictions:
print(prediction_)


# [END aiplatform_sdk_predict_image_classification_sample]
33 changes: 33 additions & 0 deletions samples/model-builder/predict_image_classification_sample_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import predict_image_classification_sample
import test_constants as constants


def test_predict_image_classification_sample(mock_sdk_init, mock_get_endpoint):

predict_image_classification_sample.predict_image_classification_sample(
project=constants.PROJECT,
location=constants.LOCATION,
endpoint_name=constants.ENDPOINT_NAME,
images=[]
)

mock_sdk_init.assert_called_once_with(
project=constants.PROJECT, location=constants.LOCATION
)

mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME,)