diff --git a/language/automl/README.md b/language/automl/README.md new file mode 100644 index 00000000000..97e17bee204 --- /dev/null +++ b/language/automl/README.md @@ -0,0 +1,88 @@ +# AutoML Sample + + +Open in Cloud Shell + +[Google Cloud Natural Language API][language] provides feature detection for images. +This API is part of the larger collection of Cloud Machine Learning APIs. + +This sample Java application demonstrates how to access the Cloud Natural Language AutoML API +using the [Google Cloud Client Library for Java][google-cloud-java]. + +[language]: https://cloud.google.com/language/docs/ +[google-cloud-java]: https://github.com/GoogleCloudPlatform/google-cloud-java + +## Set the environment variables + +PROJECT_ID = [Id of the project] +REGION_NAME = [Region name] + +## Build the sample + +Install [Maven](http://maven.apache.org/). + +Build your project with: + +``` +mvn clean package +``` + +### Dataset API + +#### Create a new dataset +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.language.samples.DatasetApi" -Dexec.args="create_dataset test_dataset" +``` + +#### List datasets +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.language.samples.DatasetApi" -Dexec.args="list_datasets" +``` + +#### Get dataset +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.language.samples.DatasetApi" -Dexec.args="get_dataset [dataset-id]" +``` + +#### Import data +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.language.samples.DatasetApi" -Dexec.args="import_data gs://java-docs-samples-testing/happiness.csv" +``` + +### Model API + +#### Create Model +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.language.samples.ModelApi" -Dexec.args="create_model test_model" +``` + +#### List Models +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.language.samples.ModelApi" -Dexec.args="list_models" +``` + +#### Get Model +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.language.samples.ModelApi" -Dexec.args="get_model [model-id]" +``` + +#### List Model Evaluations +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.language.samples.ModelApi" -Dexec.args="list_model_evaluation [model-id]" +``` + +#### Get Model Evaluation +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.language.samples.ModelApi" -Dexec.args="get_model_evaluation [model-id] [model-evaluation-id]" +``` + +#### Delete Model +``` +mvn exec:java-Dexec.mainClass="com.google.cloud.language.samples.ModelApi" -Dexec.args="delete_model [model-id]" +``` +### Predict API + +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.language.samples.PredictApi" -Dexec.args="predict [model-id] ./resources/input.txt" +``` + diff --git a/language/automl/pom.xml b/language/automl/pom.xml new file mode 100644 index 00000000000..66f481dcf40 --- /dev/null +++ b/language/automl/pom.xml @@ -0,0 +1,154 @@ + + + 4.0.0 + com.example.vision + language-automl + jar + + + + com.google.cloud.samples + shared-configuration + 1.0.9 + + + + 1.8 + 1.8 + UTF-8 + + + + + + com.google.cloud + google-cloud-automl + 0.55.1-beta + + + net.sourceforge.argparse4j + argparse4j + 0.8.1 + + + + + + junit + junit + 4.12 + test + + + + com.google.truth + truth + 0.41 + test + + + + + + DatasetApi + + + DatasetApi + + + + + + org.codehaus.mojo + exec-maven-plugin + 1.6.0 + + + + java + + + + + com.google.cloud.language.samples.DatasetApi + false + + + + + + + ModelApi + + + ModelApi + + + + + + org.codehaus.mojo + exec-maven-plugin + 1.6.0 + + + + java + + + + + com.google.cloud.language.samples.ModelApi + false + + + + + + + PredictApi + + + PredictApi + + + + + + org.codehaus.mojo + exec-maven-plugin + 1.6.0 + + + + java + + + + + com.google.cloud.language.samples.PredictionApi + false + + + + + + + diff --git a/language/automl/resources/input.txt b/language/automl/resources/input.txt new file mode 100644 index 00000000000..a711059ba58 --- /dev/null +++ b/language/automl/resources/input.txt @@ -0,0 +1 @@ +creamy, full-flavored, nutty, sweet \ No newline at end of file diff --git a/language/automl/src/main/java/com/google/cloud/language/samples/DatasetApi.java b/language/automl/src/main/java/com/google/cloud/language/samples/DatasetApi.java new file mode 100644 index 00000000000..d7aa610f589 --- /dev/null +++ b/language/automl/src/main/java/com/google/cloud/language/samples/DatasetApi.java @@ -0,0 +1,349 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.language.samples; + +// Imports the Google Cloud client library +import com.google.cloud.automl.v1beta1.AutoMlClient; +import com.google.cloud.automl.v1beta1.ClassificationProto.ClassificationType; +import com.google.cloud.automl.v1beta1.Dataset; +import com.google.cloud.automl.v1beta1.DatasetName; +import com.google.cloud.automl.v1beta1.GcsDestination; +import com.google.cloud.automl.v1beta1.GcsSource; +import com.google.cloud.automl.v1beta1.InputConfig; +import com.google.cloud.automl.v1beta1.ListDatasetsRequest; +import com.google.cloud.automl.v1beta1.LocationName; +import com.google.cloud.automl.v1beta1.OutputConfig; +import com.google.cloud.automl.v1beta1.TextClassificationDatasetMetadata; +import com.google.protobuf.Empty; + +import java.io.IOException; +import java.io.PrintStream; + +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import net.sourceforge.argparse4j.inf.Subparsers; + +/** + * Google Cloud AutoML Natural Language API sample application. Example usage: mvn package exec:java + * -Dexec.mainClass ='com.google.cloud.vision.samples.automl.DatasetAPI' -Dexec.args='create_dataset + * test_dataset' + */ +public class DatasetApi { + + // [START automl_natural_language_create_dataset] + /** + * Demonstrates using the AutoML client to create a dataset + * + * @param projectId the Google Cloud Project ID. + * @param computeRegion the Region name. (e.g., "us-central1") + * @param datasetName the name of the dataset to be created. + * @param multiLabel the type of classification problem. Set to FALSE by default. False - + * MULTICLASS , True - MULTILABEL + * @throws IOException on Input/Output errors. + */ + public static void createDataset( + String projectId, String computeRegion, String datasetName, Boolean multiLabel) + throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Classification type assigned based on multilabel value. + ClassificationType classificationType = + multiLabel ? ClassificationType.MULTILABEL : ClassificationType.MULTICLASS; + + // Specify the text classification type for the dataset. + TextClassificationDatasetMetadata textClassificationDatasetMetadata = + TextClassificationDatasetMetadata.newBuilder() + .setClassificationType(classificationType) + .build(); + + // Set dataset name and dataset metadata. + Dataset myDataset = + Dataset.newBuilder() + .setDisplayName(datasetName) + .setTextClassificationDatasetMetadata(textClassificationDatasetMetadata) + .build(); + + // Create a dataset with the dataset metadata in the region. + Dataset dataset = client.createDataset(projectLocation, myDataset); + + // Display the dataset information. + System.out.println(String.format("Dataset name: %s", dataset.getName())); + System.out.println( + String.format( + "Dataset id: %s", + dataset.getName().split("/")[dataset.getName().split("/").length - 1])); + System.out.println(String.format("Dataset display name: %s", dataset.getDisplayName())); + System.out.println("Text classification dataset metadata:"); + System.out.print(String.format("\t%s", dataset.getTextClassificationDatasetMetadata())); + System.out.println(String.format("Dataset example count: %d", dataset.getExampleCount())); + System.out.println("Dataset create time:"); + System.out.println(String.format("\tseconds: %s", dataset.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos())); + } + // [END automl_natural_language_create_dataset] + + // [START automl_natural_language_list_datasets] + /** + * Demonstrates using the AutoML client to list all datasets. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param filter the Filter expression. + * @throws IOException on Input/Output errors. + */ + public static void listDatasets(String projectId, String computeRegion, String filter) + throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Build the List datasets request + ListDatasetsRequest request = + ListDatasetsRequest.newBuilder() + .setParent(projectLocation.toString()) + .setFilter(filter) + .build(); + + // List all the datasets available in the region by applying filter. + System.out.println("List of datasets:"); + for (Dataset dataset : client.listDatasets(request).iterateAll()) { + // Display the dataset information. + System.out.println(String.format("\nDataset name: %s", dataset.getName())); + System.out.println( + String.format( + "Dataset id: %s", + dataset.getName().split("/")[dataset.getName().split("/").length - 1])); + System.out.println(String.format("Dataset display name: %s", dataset.getDisplayName())); + System.out.println("Text classification dataset metadata:"); + System.out.print(String.format("\t%s", dataset.getTextClassificationDatasetMetadata())); + System.out.println(String.format("Dataset example count: %d", dataset.getExampleCount())); + System.out.println("Dataset create time:"); + System.out.println(String.format("\tseconds: %s", dataset.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos())); + } + } + // [END automl_natural_language_list_datasets] + + // [START automl_natural_language_get_dataset] + /** + * Demonstrates using the AutoML client to get a dataset by ID. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param datasetId the Id of the dataset. + * @throws IOException on Input/Output errors. + */ + public static void getDataset(String projectId, String computeRegion, String datasetId) + throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the complete path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + + // Get all the information about a given dataset. + Dataset dataset = client.getDataset(datasetFullId); + + // Display the dataset information. + System.out.println(String.format("Dataset name: %s", dataset.getName())); + System.out.println( + String.format( + "Dataset id: %s", + dataset.getName().split("/")[dataset.getName().split("/").length - 1])); + System.out.println(String.format("Dataset display name: %s", dataset.getDisplayName())); + System.out.println("Text classification dataset metadata:"); + System.out.print(String.format("\t%s", dataset.getTextClassificationDatasetMetadata())); + System.out.println(String.format("Dataset example count: %d", dataset.getExampleCount())); + System.out.println("Dataset create time:"); + System.out.println(String.format("\tseconds: %s", dataset.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos())); + } + // [END automl_natural_language_get_dataset] + + // [START automl_natural_language_import_data] + /** + * Import labeled items. + * + * @param projectId - Id of the project. + * @param computeRegion - Region name. + * @param datasetId - Id of the dataset into which the training content are to be imported. + * @param path - Google Cloud Storage URIs. Target files must be in AutoML Natural Language CSV + * format. + * @throws Exception on AutoML Client errors + */ + public static void importData( + String projectId, String computeRegion, String datasetId, String path) throws Exception { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the complete path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + + // Get multiple training data files to be imported + String[] inputUris = path.split(","); + for (String inputUri : inputUris) { + gcsSource.addInputUris(inputUri); + } + + // Import data from the input URI + InputConfig inputConfig = InputConfig.newBuilder().setGcsSource(gcsSource).build(); + System.out.println("Processing import..."); + + Empty response = client.importDataAsync(datasetFullId, inputConfig).get(); + System.out.println(String.format("Dataset imported. %s", response)); + } + // [END automl_natural_language_import_data] + + // [START automl_natural_language_export_data] + /** + * Demonstrates using the AutoML client to export a dataset to a Google Cloud Storage bucket. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param datasetId the Id of the dataset. + * @param gcsUri the Destination URI (Google Cloud Storage) + * @throws Exception on AutoML Client errors + */ + public static void exportData( + String projectId, String computeRegion, String datasetId, String gcsUri) throws Exception { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the complete path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + + // Set the output URI. + GcsDestination gcsDestination = GcsDestination.newBuilder().setOutputUriPrefix(gcsUri).build(); + + // Export the data to the output URI. + OutputConfig outputConfig = OutputConfig.newBuilder().setGcsDestination(gcsDestination).build(); + System.out.println(String.format("Processing export...")); + + Empty response = client.exportDataAsync(datasetFullId, outputConfig).get(); + System.out.println(String.format("Dataset exported. %s", response)); + } + // [END automl_natural_language_export_data] + + // [START automl_natural_language_delete_dataset] + /** + * Delete a dataset. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param datasetId the Id of the dataset. + * @throws Exception on AutoML Client errors + */ + public static void deleteDataset(String projectId, String computeRegion, String datasetId) + throws Exception { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the complete path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + + // Delete a dataset. + Empty response = client.deleteDatasetAsync(datasetFullId).get(); + + System.out.println(String.format("Dataset deleted. %s", response)); + } + // [END automl_natural_language_delete_dataset] + + public static void main(String[] args) throws Exception { + DatasetApi datasetApi = new DatasetApi(); + datasetApi.argsHelper(args, System.out); + } + + public static void argsHelper(String[] args, PrintStream out) throws Exception { + ArgumentParser parser = + ArgumentParsers.newFor("DatasetApi") + .build() + .defaultHelp(true) + .description("Dataset API operations."); + Subparsers subparsers = parser.addSubparsers().dest("command"); + + Subparser createDatasetParser = subparsers.addParser("create_dataset"); + createDatasetParser.addArgument("datasetName"); + createDatasetParser + .addArgument("multiLabel") + .nargs("?") + .type(Boolean.class) + .choices(Boolean.FALSE, Boolean.TRUE) + .setDefault("False"); + + Subparser listDatasetsParser = subparsers.addParser("list_datasets"); + listDatasetsParser + .addArgument("filter") + .nargs("?") + .setDefault("textClassificationDatasetMetadata:*"); + + Subparser getDatasetParser = subparsers.addParser("get_dataset"); + getDatasetParser.addArgument("datasetId"); + + Subparser importDataParser = subparsers.addParser("import_data"); + importDataParser.addArgument("datasetId"); + importDataParser.addArgument("path"); + + Subparser exportDataParser = subparsers.addParser("export_data"); + exportDataParser.addArgument("datasetId"); + exportDataParser.addArgument("outputUri"); + + Subparser deleteDatasetParser = subparsers.addParser("delete_dataset"); + deleteDatasetParser.addArgument("datasetId"); + + String projectId = System.getenv("PROJECT_ID"); + String computeRegion = System.getenv("REGION_NAME"); + + Namespace ns = null; + try { + ns = parser.parseArgs(args); + + if (ns.get("command").equals("create_dataset")) { + createDataset( + projectId, computeRegion, ns.getString("datasetName"), ns.getBoolean("multiLabel")); + } + if (ns.get("command").equals("list_datasets")) { + listDatasets(projectId, computeRegion, ns.getString("filter")); + } + if (ns.get("command").equals("get_dataset")) { + getDataset(projectId, computeRegion, ns.getString("datasetId")); + } + if (ns.get("command").equals("import_data")) { + importData(projectId, computeRegion, ns.getString("datasetId"), ns.getString("path")); + } + if (ns.get("command").equals("export_data")) { + exportData(projectId, computeRegion, ns.getString("datasetId"), ns.getString("outputUri")); + } + if (ns.get("command").equals("delete_dataset")) { + deleteDataset(projectId, computeRegion, ns.getString("datasetId")); + } + + } catch (ArgumentParserException e) { + parser.handleError(e); + } + } +} diff --git a/language/automl/src/main/java/com/google/cloud/language/samples/ModelApi.java b/language/automl/src/main/java/com/google/cloud/language/samples/ModelApi.java new file mode 100644 index 00000000000..16b4841f290 --- /dev/null +++ b/language/automl/src/main/java/com/google/cloud/language/samples/ModelApi.java @@ -0,0 +1,428 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.language.samples; + +// Imports the Google Cloud client library +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.automl.v1beta1.AutoMlClient; +import com.google.cloud.automl.v1beta1.ClassificationProto.ClassificationEvaluationMetrics; +import com.google.cloud.automl.v1beta1.ClassificationProto.ClassificationEvaluationMetrics.ConfidenceMetricsEntry; +import com.google.cloud.automl.v1beta1.ListModelEvaluationsRequest; +import com.google.cloud.automl.v1beta1.ListModelsRequest; +import com.google.cloud.automl.v1beta1.LocationName; +import com.google.cloud.automl.v1beta1.Model; +import com.google.cloud.automl.v1beta1.ModelEvaluation; +import com.google.cloud.automl.v1beta1.ModelEvaluationName; +import com.google.cloud.automl.v1beta1.ModelName; +import com.google.cloud.automl.v1beta1.OperationMetadata; + +import com.google.cloud.automl.v1beta1.TextClassificationModelMetadata; +import com.google.longrunning.Operation; +import com.google.protobuf.Empty; + +import java.io.IOException; +import java.io.PrintStream; +import java.util.List; +import java.util.concurrent.ExecutionException; +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import net.sourceforge.argparse4j.inf.Subparsers; + +/** + * Google Cloud AutoML Natural Language API sample application. Example usage: mvn package exec:java + * -Dexec.mainClass ='com.google.cloud.vision.samples.automl.ModelApi' -Dexec.args='create_model + * [datasetId] test_model' + */ +public class ModelApi { + + // [START automl_natural_language_create_model] + /** + * Demonstrates using the AutoML client to create a model. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param dataSetId the Id of the dataset to which model is created. + * @param modelName the Name of the model. + * @throws Exception on AutoML Client errors + */ + public static void createModel( + String projectId, String computeRegion, String dataSetId, String modelName) + throws IOException, InterruptedException, ExecutionException { + + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Set model meta data + TextClassificationModelMetadata textClassificationModelMetadata = + TextClassificationModelMetadata.newBuilder().build(); + + // Set model name, dataset and metadata. + Model myModel = + Model.newBuilder() + .setDisplayName(modelName) + .setDatasetId(dataSetId) + .setTextClassificationModelMetadata(textClassificationModelMetadata) + .build(); + + // Create a model with the model metadata in the region. + OperationFuture response = + client.createModelAsync(projectLocation, myModel); + + System.out.println( + String.format("Training operation name: %s", response.getInitialFuture().get().getName())); + System.out.println("Training started..."); + } + // [END automl_natural_language_create_model] + + // [START automl_natural_language_get_operation_status] + /** + * Demonstrates using the AutoML client to get operation status. + * + * @param operationFullId the complete name of a operation. For example, the name of your + * operation is projects/[projectId]/locations/us-central1/operations/[operationId]. + * @throws IOException on Input/Output errors. + */ + public static void getOperationStatus(String operationFullId) throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the latest state of a long-running operation. + Operation response = client.getOperationsClient().getOperation(operationFullId); + + System.out.println(String.format("Operation status: %s", response)); + } + // [END automl_natural_language_get_operation_status] + + // [START automl_natural_language_list_models] + /** + * Demonstrates using the AutoML client to list all models. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param filter the filter expression. + * @throws IOException on Input/Output errors. + */ + public static void listModels(String projectId, String computeRegion, String filter) + throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Create list models request. + ListModelsRequest listModlesRequest = + ListModelsRequest.newBuilder() + .setParent(projectLocation.toString()) + .setFilter(filter) + .build(); + + System.out.println("List of models:"); + for (Model model : client.listModels(listModlesRequest).iterateAll()) { + // Display the model information. + System.out.println(String.format("Model name: %s", model.getName())); + System.out.println( + String.format( + "Model id: %s", model.getName().split("/")[model.getName().split("/").length - 1])); + System.out.println(String.format("Model display name: %s", model.getDisplayName())); + System.out.println("Model create time:"); + System.out.println(String.format("\tseconds: %s", model.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", model.getCreateTime().getNanos())); + System.out.println(String.format("Model deployment state: %s", model.getDeploymentState())); + } + } + // [END automl_natural_language_list_models] + + // [START automl_natural_language_get_model] + /** + * Demonstrates using the AutoML client to get model details. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model. + * @throws IOException on AutoML Client errors + */ + public static void getModel(String projectId, String computeRegion, String modelId) + throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + + // Get complete detail of the model. + Model model = client.getModel(modelFullId); + + // Display the model information. + System.out.println(String.format("Model name: %s", model.getName())); + System.out.println( + String.format( + "Model id: %s", model.getName().split("/")[model.getName().split("/").length - 1])); + System.out.println(String.format("Model display name: %s", model.getDisplayName())); + System.out.println("Model create time:"); + System.out.println(String.format("\tseconds: %s", model.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", model.getCreateTime().getNanos())); + System.out.println(String.format("Model deployment state: %s", model.getDeploymentState())); + } + // END automl_natural_language_get_model] + + // [START automl_natural_language_list_model_evaluations] + /** + * Demonstrates using the AutoML client to list model evaluations. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model. + * @param filter the Filter expression. + * @throws IOException on Input/Output errors. + */ + public static void listModelEvaluations( + String projectId, String computeRegion, String modelId, String filter) throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + + // Create list model evaluations request. + ListModelEvaluationsRequest modelEvaluationsRequest = + ListModelEvaluationsRequest.newBuilder() + .setParent(modelFullId.toString()) + .setFilter(filter) + .build(); + + // List all the model evaluations in the model by applying filter. + for (ModelEvaluation element : + client.listModelEvaluations(modelEvaluationsRequest).iterateAll()) { + System.out.println(element); + } + } + // [END automl_natural_language_list_model_evaluations] + + // [START automl_natural_language_get_model_evaluation] + /** + * Demonstrates using the AutoML client to get model evaluations. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model. + * @param modelEvaluationId the Id of your model evaluation. + * @throws IOException on Input/Output errors. + */ + public static void getModelEvaluation( + String projectId, String computeRegion, String modelId, String modelEvaluationId) + throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the model evaluation. + ModelEvaluationName modelEvaluationFullId = + ModelEvaluationName.of(projectId, computeRegion, modelId, modelEvaluationId); + + // Get complete detail of the model evaluation. + ModelEvaluation response = client.getModelEvaluation(modelEvaluationFullId); + + System.out.println(response); + } + // [END automl_natural_language_get_model_evaluation] + + // [START automl_natural_language_display_evaluation] + /** + * Demonstrates using the AutoML client to display model evaluation. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model. + * @param filter the Filter expression. + * @throws IOException on Input/Output errors. + */ + public static void displayEvaluation( + String projectId, String computeRegion, String modelId, String filter) throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + + // List all the model evaluations in the model by applying. + ListModelEvaluationsRequest modelEvaluationsrequest = + ListModelEvaluationsRequest.newBuilder() + .setParent(modelFullId.toString()) + .setFilter(filter) + .build(); + + // Iterate through the results. + String modelEvaluationId = ""; + for (ModelEvaluation element : + client.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + if (element.getAnnotationSpecId() != null) { + modelEvaluationId = element.getName().split("/")[element.getName().split("/").length - 1]; + } + } + + // Resource name for the model evaluation. + ModelEvaluationName modelEvaluationFullId = + ModelEvaluationName.of(projectId, computeRegion, modelId, modelEvaluationId); + + // Get a model evaluation. + ModelEvaluation modelEvaluation = client.getModelEvaluation(modelEvaluationFullId); + + ClassificationEvaluationMetrics classMetrics = + modelEvaluation.getClassificationEvaluationMetrics(); + List confidenceMetricsEntries = + classMetrics.getConfidenceMetricsEntryList(); + + // Showing model score based on threshold of 0.5 + for (ConfidenceMetricsEntry confidenceMetricsEntry : confidenceMetricsEntries) { + if (confidenceMetricsEntry.getConfidenceThreshold() == 0.5) { + System.out.println("Precision and recall are based on a score threshold of 0.5"); + System.out.println( + String.format("Model Precision: %.2f ", confidenceMetricsEntry.getPrecision() * 100) + + '%'); + System.out.println( + String.format("Model Recall: %.2f ", confidenceMetricsEntry.getRecall() * 100) + '%'); + System.out.println( + String.format("Model F1 Score: %.2f ", confidenceMetricsEntry.getF1Score() * 100) + + '%'); + System.out.println( + String.format( + "Model Precision@1: %.2f ", confidenceMetricsEntry.getPrecisionAt1() * 100) + + '%'); + System.out.println( + String.format("Model Recall@1: %.2f ", confidenceMetricsEntry.getRecallAt1() * 100) + + '%'); + System.out.println( + String.format("Model F1 Score@1: %.2f ", confidenceMetricsEntry.getF1ScoreAt1() * 100) + + '%'); + } + } + } + // [END automl_natural_language_display_evaluation] + + // [START automl_natural_language_delete_model] + /** + * Demonstrates using the AutoML client to delete a model. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model. + * @throws Exception on AutoML Client errors + */ + public static void deleteModel(String projectId, String computeRegion, String modelId) + throws InterruptedException, ExecutionException, IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + + // Delete a model. + Empty response = client.deleteModelAsync(modelFullId).get(); + + System.out.println("Model deletion started..."); + } + // [END automl_natural_language_delete_model] + + public static void main(String[] args) throws Exception { + ModelApi modelApi = new ModelApi(); + modelApi.argsHelper(args, System.out); + } + + public static void argsHelper(String[] args, PrintStream out) throws Exception { + ArgumentParser parser = + ArgumentParsers.newFor("ModelApi") + .build() + .defaultHelp(true) + .description("Model API operations."); + Subparsers subparsers = parser.addSubparsers().dest("command"); + + Subparser createModelParser = subparsers.addParser("create_model"); + createModelParser.addArgument("datasetId"); + createModelParser.addArgument("modelName"); + + Subparser listModelsParser = subparsers.addParser("list_models"); + listModelsParser + .addArgument("filter") + .nargs("?") + .setDefault("textClassificationModelMetadata:*"); + + Subparser getModelParser = subparsers.addParser("get_model"); + getModelParser.addArgument("modelId"); + + Subparser listModelEvaluationsParser = subparsers.addParser("list_model_evaluations"); + listModelEvaluationsParser.addArgument("modelId"); + listModelEvaluationsParser.addArgument("filter").nargs("?").setDefault(""); + + Subparser getModelEvaluationParser = subparsers.addParser("get_model_evaluation"); + getModelEvaluationParser.addArgument("modelId"); + getModelEvaluationParser.addArgument("modelEvaluationId"); + + Subparser displayEvaluationParser = subparsers.addParser("display_evaluation"); + displayEvaluationParser.addArgument("modelId"); + displayEvaluationParser.addArgument("filter").nargs("?").setDefault(""); + + Subparser deleteModelParser = subparsers.addParser("delete_model"); + deleteModelParser.addArgument("modelId"); + + Subparser getOperationStatusParser = subparsers.addParser("get_operation_status"); + getOperationStatusParser.addArgument("operationFullId"); + + String projectId = System.getenv("PROJECT_ID"); + String computeRegion = System.getenv("REGION_NAME"); + + Namespace ns = null; + try { + ns = parser.parseArgs(args); + if (ns.get("command").equals("create_model")) { + createModel(projectId, computeRegion, ns.getString("datasetId"), ns.getString("modelName")); + } + if (ns.get("command").equals("list_models")) { + listModels(projectId, computeRegion, ns.getString("filter")); + } + if (ns.get("command").equals("get_model")) { + getModel(projectId, computeRegion, ns.getString("modelId")); + } + if (ns.get("command").equals("list_model_evaluations")) { + listModelEvaluations( + projectId, computeRegion, ns.getString("modelId"), ns.getString("filter")); + } + if (ns.get("command").equals("get_model_evaluation")) { + getModelEvaluation( + projectId, computeRegion, ns.getString("modelId"), ns.getString("modelEvaluationId")); + } + if (ns.get("command").equals("delete_model")) { + deleteModel(projectId, computeRegion, ns.getString("modelId")); + } + if (ns.get("command").equals("get_operation_status")) { + getOperationStatus(ns.getString("operationFullId")); + } + if (ns.get("command").equals("display_evaluation")) { + displayEvaluation( + projectId, computeRegion, ns.getString("modelId"), ns.getString("filter")); + } + + } catch (ArgumentParserException e) { + parser.handleError(e); + } + } +} diff --git a/language/automl/src/main/java/com/google/cloud/language/samples/PredictionApi.java b/language/automl/src/main/java/com/google/cloud/language/samples/PredictionApi.java new file mode 100644 index 00000000000..b374deded7e --- /dev/null +++ b/language/automl/src/main/java/com/google/cloud/language/samples/PredictionApi.java @@ -0,0 +1,121 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.language.samples; + +// Imports the Google Cloud client library +import com.google.cloud.automl.v1beta1.AnnotationPayload; +import com.google.cloud.automl.v1beta1.ExamplePayload; +import com.google.cloud.automl.v1beta1.ModelName; +import com.google.cloud.automl.v1beta1.PredictResponse; +import com.google.cloud.automl.v1beta1.PredictionServiceClient; +import com.google.cloud.automl.v1beta1.TextSnippet; + +import java.io.IOException; +import java.io.PrintStream; + +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import net.sourceforge.argparse4j.inf.Subparsers; + +/** + * Google Cloud AutoML Natural Language API sample application. Example usage: mvn package exec:java + * -Dexec.mainClass ='com.google.cloud.vision.samples.automl.PredictionApi' -Dexec.args='predict + * [modelId] [path-to-text-file] [scoreThreshold]' + */ +public class PredictionApi { + + // [START automl_natural_language_predict] + /** + * Demonstrates using the AutoML client to classify the text content + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model which will be used for text classification. + * @param filePath the Local text file path of the content to be classified. + * @throws IOException on Input/Output errors. + */ + public static void predict( + String projectId, String computeRegion, String modelId, String filePath) throws IOException { + + // Create client for prediction service. + PredictionServiceClient predictionClient = PredictionServiceClient.create(); + + // Get full path of model + ModelName name = ModelName.of(projectId, computeRegion, modelId); + + // Read the file content for prediction. + String content = new String(Files.readAllBytes(Paths.get(filePath))); + + // Set the payload by giving the content and type of the file. + TextSnippet textSnippet = + TextSnippet.newBuilder().setContent(content).setMimeType("text/plain").build(); + ExamplePayload payload = ExamplePayload.newBuilder().setTextSnippet(textSnippet).build(); + + // params is additional domain-specific parameters. + // currently there is no additional parameters supported. + Map params = new HashMap(); + PredictResponse response = predictionClient.predict(name, payload, params); + + System.out.println("Prediction results:"); + for (AnnotationPayload annotationPayload : response.getPayloadList()) { + System.out.println("Predicted Class name :" + annotationPayload.getDisplayName()); + System.out.println( + "Predicted Class Score :" + annotationPayload.getClassification().getScore()); + } + } + // [END automl_natural_language_predict] + + public static void main(String[] args) throws IOException { + PredictionApi predictionApi = new PredictionApi(); + predictionApi.argsHelper(args, System.out); + } + + public static void argsHelper(String[] args, PrintStream out) throws IOException { + ArgumentParser parser = + ArgumentParsers.newFor("PredictionApi") + .build() + .defaultHelp(true) + .description("Prediction API Operation"); + Subparsers subparsers = parser.addSubparsers().dest("command"); + + Subparser predictParser = subparsers.addParser("predict"); + predictParser.addArgument("modelId"); + predictParser.addArgument("filePath"); + + String projectId = System.getenv("PROJECT_ID"); + String computeRegion = System.getenv("REGION_NAME"); + + Namespace ns = null; + try { + ns = parser.parseArgs(args); + if (ns.get("command").equals("predict")) { + predict(projectId, computeRegion, ns.getString("modelId"), ns.getString("filePath")); + } + + } catch (ArgumentParserException e) { + parser.handleError(e); + } + } +} diff --git a/language/automl/src/test/java/com/google/cloud/language/samples/DatasetApiIT.java b/language/automl/src/test/java/com/google/cloud/language/samples/DatasetApiIT.java new file mode 100644 index 00000000000..35f2dbf2f71 --- /dev/null +++ b/language/automl/src/test/java/com/google/cloud/language/samples/DatasetApiIT.java @@ -0,0 +1,108 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.language.samples; + +import static com.google.common.truth.Truth.assertThat; +import static java.lang.Boolean.FALSE; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for Automl natural language "Dataset API" sample. */ +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class DatasetApiIT { + + private static final String PROJECT_ID = "java-docs-samples-testing"; + private static final String BUCKET = PROJECT_ID + "-vcm"; + private static final String COMPUTE_REGION = "us-central1"; + private static final String DATASET_NAME = "test_language_dataset"; + private ByteArrayOutputStream bout; + private PrintStream out; + private DatasetApi app; + private String datasetId; + private String getdatasetId = "8477830379477056918"; + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() { + System.setOut(null); + } + + @Test + public void testCreateImportDeleteDataset() throws Exception { + // Act + DatasetApi.createDataset(PROJECT_ID, COMPUTE_REGION, DATASET_NAME, FALSE); + + // Assert + String got = bout.toString(); + datasetId = + bout.toString() + .split("\n")[0] + .split("/")[(bout.toString().split("\n")[0]).split("/").length - 1]; + assertThat(got).contains("Dataset id:"); + + // Act + DatasetApi.importData( + PROJECT_ID, COMPUTE_REGION, datasetId, "gs://" + BUCKET + "/happiness.csv"); + + // Assert + got = bout.toString(); + assertThat(got).contains("Dataset id:"); + + // Act + DatasetApi.deleteDataset(PROJECT_ID, COMPUTE_REGION, datasetId); + + // Assert + got = bout.toString(); + assertThat(got).contains("Dataset deleted."); + } + + @Test + public void testListDatasets() throws Exception { + // Act + DatasetApi.listDatasets(PROJECT_ID, COMPUTE_REGION, ""); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Dataset id:"); + } + + @Test + public void testGetDataset() throws Exception { + + // Act + DatasetApi.getDataset(PROJECT_ID, COMPUTE_REGION, getdatasetId); + + // Assert + String got = bout.toString(); + + assertThat(got).contains("Dataset id:"); + } +} diff --git a/language/automl/src/test/java/com/google/cloud/language/samples/ModelApiIT.java b/language/automl/src/test/java/com/google/cloud/language/samples/ModelApiIT.java new file mode 100644 index 00000000000..c142c3ef5b2 --- /dev/null +++ b/language/automl/src/test/java/com/google/cloud/language/samples/ModelApiIT.java @@ -0,0 +1,92 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.language.samples; + +import static com.google.common.truth.Truth.assertThat; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for vision "Model API" sample. */ +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class ModelApiIT { + private static final String PROJECT_ID = "java-docs-samples-testing"; + private static final String COMPUTE_REGION = "us-central1"; + private ByteArrayOutputStream bout; + private PrintStream out; + private ModelApi app; + private String modelId; + private String modelIdGetevaluation = "342705131419266916"; + private String modelEvaluationId = "3666189665418739402"; + + @Before + public void setUp() { + + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() { + System.setOut(null); + } + + @Test + public void testModelApi() throws Exception { + // Act + ModelApi.listModels(PROJECT_ID, COMPUTE_REGION, ""); + + // Assert + String got = bout.toString(); + modelId = got.split("\n")[1].split("/")[got.split("\n")[1].split("/").length - 1]; + assertThat(got).contains("Model id:"); + + // Act + ModelApi.getModel(PROJECT_ID, COMPUTE_REGION, modelId); + + // Assert + got = bout.toString(); + assertThat(got).contains("Model name:"); + + // Act + ModelApi.listModelEvaluations(PROJECT_ID, COMPUTE_REGION, modelId, ""); + + // Assert + got = bout.toString(); + assertThat(got).contains("name:"); + } + + @Test + public void testGetModelEvaluation() throws Exception { + + // Act + ModelApi.getModelEvaluation( + PROJECT_ID, COMPUTE_REGION, modelIdGetevaluation, modelEvaluationId); + + // Assert + String got = bout.toString(); + assertThat(got).contains("name:"); + } +} diff --git a/language/automl/src/test/java/com/google/cloud/language/samples/PredictionApiIT.java b/language/automl/src/test/java/com/google/cloud/language/samples/PredictionApiIT.java new file mode 100644 index 00000000000..3549eb244ee --- /dev/null +++ b/language/automl/src/test/java/com/google/cloud/language/samples/PredictionApiIT.java @@ -0,0 +1,63 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.language.samples; + +import static com.google.common.truth.Truth.assertThat; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for vision "PredictionAPI" sample. */ +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class PredictionApiIT { + private static final String COMPUTE_REGION = "us-central1"; + private static final String PROJECT_ID = "java-docs-samples-testing"; + private static final String modelId = "342705131419266916"; + private static final String filePath = "./resources/input.txt"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PredictionApi app; + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() { + System.setOut(null); + } + + @Test + public void testPredict() throws Exception { + // Act + PredictionApi.predict(PROJECT_ID, COMPUTE_REGION, modelId, filePath); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Prediction results:"); + } +} diff --git a/translate/automl/README.md b/translate/automl/README.md new file mode 100644 index 00000000000..e2248a17de7 --- /dev/null +++ b/translate/automl/README.md @@ -0,0 +1,86 @@ +# AutoML Translate Sample + + +Open in Cloud Shell + +[Google Cloud Translate API][translate] provides feature AutoML. +This API is part of the larger collection of Cloud Machine Learning APIs. + +This sample Java application demonstrates how to access the Cloud Translate AutoML API +using the [Google Cloud Client Library for Java][google-cloud-java]. + +## Set the environment variables + +PROJECT_ID = [Id of the project] +REGION_NAME = [Region name] + +## Build the sample + +Install [Maven](http://maven.apache.org/). + +Build your project with: + +``` +mvn clean package +``` + +### Dataset API + +#### Create a new dataset +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.translate.samples.DatasetApi" -Dexec.args="create_dataset test_dataset" +``` + +#### List datasets +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.translate.samples.DatasetApi" -Dexec.args="list_datasets" +``` + +#### Get dataset +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.translate.samples.DatasetApi" -Dexec.args="get_dataset [dataset-id]" +``` + +#### Import data +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.translate.samples.DatasetApi" -Dexec.args="import_data gs://java-docs-samples-testing/en-ja.csv" +``` + +### Model API + +#### Create Model +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.translate.samples.ModelApi" -Dexec.args="create_model test_model" +``` + +#### List Models +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.translate.samples.ModelApi" -Dexec.args="list_models" +``` + +#### Get Model +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.translate.samples.ModelApi" -Dexec.args="get_model [model-id]" +``` + +#### List Model Evaluations +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.translate.samples.ModelApi" -Dexec.args="list_model_evaluation [model-id]" +``` + +#### Get Model Evaluation +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.translate.samples.ModelApi" -Dexec.args="get_model_evaluation [model-id] [model-evaluation-id]" +``` + +#### Delete Model +``` +mvn exec:java-Dexec.mainClass="com.google.cloud.translate.samples.ModelApi" -Dexec.args="delete_model [model-id]" +``` +### Predict API + +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.translate.samples.PredictApi" -Dexec.args="predict [model-id] ./resources/input.txt" +``` + + diff --git a/translate/automl/pom.xml b/translate/automl/pom.xml new file mode 100644 index 00000000000..488081b3fe2 --- /dev/null +++ b/translate/automl/pom.xml @@ -0,0 +1,153 @@ + + + 4.0.0 + com.google.cloud.translate.automl + translate-automl + jar + + + + com.google.cloud.samples + shared-configuration + 1.0.9 + + + + 1.8 + 1.8 + UTF-8 + + + + + + com.google.cloud + google-cloud-automl + 0.55.0-beta + + + net.sourceforge.argparse4j + argparse4j + 0.8.1 + + + + + + junit + junit + 4.12 + test + + + com.google.truth + truth + 0.41 + test + + + + + + DatasetApi + + + DatasetApi + + + + + + org.codehaus.mojo + exec-maven-plugin + 1.6.0 + + + + java + + + + + com.google.cloud.translate.automl.DatasetApi + false + + + + + + + ModelApi + + + ModelApi + + + + + + org.codehaus.mojo + exec-maven-plugin + 1.6.0 + + + + java + + + + + com.google.cloud.translate.automl.ModelApi + false + + + + + + + PredictionApi + + + PredictionApi + + + + + + org.codehaus.mojo + exec-maven-plugin + 1.6.0 + + + + java + + + + + com.google.cloud.translate.automl.PredictionApi + false + + + + + + + diff --git a/translate/automl/resources/input.txt b/translate/automl/resources/input.txt new file mode 100644 index 00000000000..5aecd6590fc --- /dev/null +++ b/translate/automl/resources/input.txt @@ -0,0 +1 @@ +Tell me how this ends \ No newline at end of file diff --git a/translate/automl/src/main/java/com/google/cloud/translate/automl/DatasetApi.java b/translate/automl/src/main/java/com/google/cloud/translate/automl/DatasetApi.java new file mode 100644 index 00000000000..03cc96aa45f --- /dev/null +++ b/translate/automl/src/main/java/com/google/cloud/translate/automl/DatasetApi.java @@ -0,0 +1,314 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.translate.automl; + +// Imports the Google Cloud client library +import com.google.cloud.automl.v1beta1.AutoMlClient; +import com.google.cloud.automl.v1beta1.Dataset; +import com.google.cloud.automl.v1beta1.DatasetName; +import com.google.cloud.automl.v1beta1.GcsSource; +import com.google.cloud.automl.v1beta1.GcsSource.Builder; +import com.google.cloud.automl.v1beta1.InputConfig; +import com.google.cloud.automl.v1beta1.ListDatasetsRequest; +import com.google.cloud.automl.v1beta1.LocationName; +import com.google.cloud.automl.v1beta1.TranslationDatasetMetadata; +import com.google.protobuf.Empty; + +import java.io.IOException; +import java.io.PrintStream; + +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import net.sourceforge.argparse4j.inf.Subparsers; + +/** + * Google Cloud AutoML Translate API sample application. Example usage: mvn package exec:java + * -Dexec.mainClass ='com.google.cloud.translate.samples.DatasetAPI' -Dexec.args='create_dataset + * test_dataset' + */ +public class DatasetApi { + + // [START automl_translate_create_dataset] + /** + * Demonstrates using the AutoML client to create a dataset + * + * @param projectId the Google Cloud Project ID. + * @param computeRegion the Region name. (e.g., "us-central1"). + * @param datasetName the name of the dataset to be created. + * @param source the Source language + * @param target the Target language + * @throws IOException on Input/Output errors. + */ + public static void createDataset( + String projectId, String computeRegion, String datasetName, String source, String target) + throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Specify the source and target language. + TranslationDatasetMetadata translationDatasetMetadata = + TranslationDatasetMetadata.newBuilder() + .setSourceLanguageCode(source) + .setTargetLanguageCode(target) + .build(); + + // Set dataset name and dataset metadata. + Dataset myDataset = + Dataset.newBuilder() + .setDisplayName(datasetName) + .setTranslationDatasetMetadata(translationDatasetMetadata) + .build(); + + // Create a dataset with the dataset metadata in the region. + Dataset dataset = client.createDataset(projectLocation, myDataset); + + // Display the dataset information. + System.out.println(String.format("Dataset name: %s", dataset.getName())); + System.out.println( + String.format( + "Dataset id: %s", + dataset.getName().split("/")[dataset.getName().split("/").length - 1])); + System.out.println(String.format("Dataset display name: %s", dataset.getDisplayName())); + System.out.println("Translation dataset Metadata:"); + System.out.println( + String.format( + "\tSource language code: %s", + dataset.getTranslationDatasetMetadata().getSourceLanguageCode())); + System.out.println( + String.format( + "\tTarget language code: %s", + dataset.getTranslationDatasetMetadata().getTargetLanguageCode())); + System.out.println("Dataset create time:"); + System.out.println(String.format("\tseconds: %s", dataset.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos())); + } + // [END automl_translation_create_dataset] + + // [START automl_translation_list_datasets] + /** + * Demonstrates using the AutoML client to list all datasets. + * + * @param projectId the Google Cloud Project ID. + * @param computeRegion the Region name. (e.g., "us-central1"). + * @param filter the Filter expression. + * @throws Exception on AutoML Client errors + */ + public static void listDatasets(String projectId, String computeRegion, String filter) + throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + ListDatasetsRequest request = + ListDatasetsRequest.newBuilder() + .setParent(projectLocation.toString()) + .setFilter(filter) + .build(); + + // List all the datasets available in the region by applying filter. + System.out.println("List of datasets:"); + for (Dataset dataset : client.listDatasets(request).iterateAll()) { + // Display the dataset information + System.out.println(String.format("\nDataset name: %s", dataset.getName())); + System.out.println( + String.format( + "Dataset id: %s", + dataset.getName().split("/")[dataset.getName().split("/").length - 1])); + System.out.println(String.format("Dataset display name: %s", dataset.getDisplayName())); + System.out.println("Translation dataset metadata:"); + System.out.println( + String.format( + "\tSource language code: %s", + dataset.getTranslationDatasetMetadata().getSourceLanguageCode())); + System.out.println( + String.format( + "\tTarget language code: %s", + dataset.getTranslationDatasetMetadata().getTargetLanguageCode())); + System.out.println("Dataset create time:"); + System.out.println(String.format("\tseconds: %s", dataset.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos())); + } + } + // [END automl_translation_list_datasets] + + // [START automl_translation_get_dataset] + /** + * Demonstrates using the AutoML client to get a dataset by ID. + * + * @param projectId the Google Cloud Project ID. + * @param computeRegion the Region name. (e.g., "us-central1"). + * @param datasetId the Id of the dataset. + * @throws Exception on AutoML Client errors + */ + public static void getDataset(String projectId, String computeRegion, String datasetId) + throws Exception { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the complete path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + + // Get all the information about a given dataset. + Dataset dataset = client.getDataset(datasetFullId); + + // Display the dataset information + System.out.println(String.format("Dataset name: %s", dataset.getName())); + System.out.println( + String.format( + "Dataset id: %s", + dataset.getName().split("/")[dataset.getName().split("/").length - 1])); + System.out.println(String.format("Dataset display name: %s", dataset.getDisplayName())); + System.out.println("Translation dataset metadata:"); + System.out.println( + String.format( + "\tSource language code: %s", + dataset.getTranslationDatasetMetadata().getSourceLanguageCode())); + System.out.println( + String.format( + "\tTarget language code: %s", + dataset.getTranslationDatasetMetadata().getTargetLanguageCode())); + System.out.println("Dataset create time:"); + System.out.println(String.format("\tseconds: %s", dataset.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos())); + } + // [END automl_translation_get_dataset] + + // [START automl_translation_import_data] + /** + * Import sentence pairs to the dataset. + * + * @param projectId the Google Cloud Project ID. + * @param computeRegion the Region name. (e.g., "us-central1"). + * @param datasetId the Id of the dataset. + * @param path the remote Path of the training data csv file. + * @throws Exception on AutoML Client errors + */ + public static void importData( + String projectId, String computeRegion, String datasetId, String path) throws Exception { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the complete path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + + Builder gcsSource = GcsSource.newBuilder(); + + // Get multiple Google Cloud Storage URIs to import data from + String[] inputUris = path.split(","); + for (String inputUri : inputUris) { + gcsSource.addInputUris(inputUri); + } + + // Import data from the input URI + InputConfig inputConfig = InputConfig.newBuilder().setGcsSource(gcsSource).build(); + System.out.println("Processing import..."); + + Empty response = client.importDataAsync(datasetFullId, inputConfig).get(); + System.out.println(String.format("Dataset imported. %s", response)); + } + // [END automl_translation_import_data] + + // [START automl_translation_delete_dataset] + /** + * Delete a dataset. + * + * @param projectId the Google Cloud Project ID. + * @param computeRegion the Region name. (e.g., "us-central1"). + * @param datasetId the Id of the dataset. + * @throws Exception on AutoML Client errors + */ + public static void deleteDataset(String projectId, String computeRegion, String datasetId) + throws Exception { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + + // Delete a dataset. + Empty response = client.deleteDatasetAsync(datasetFullId).get(); + + System.out.println(String.format("Dataset deleted. %s", response)); + } + // [END automl_translation_delete_dataset] + + public static void main(String[] args) throws Exception { + DatasetApi datasetApi = new DatasetApi(); + datasetApi.argsHelper(args, System.out); + } + + public static void argsHelper(String[] args, PrintStream out) throws Exception { + ArgumentParser parser = ArgumentParsers.newFor("").build(); + Subparsers subparsers = parser.addSubparsers().dest("command"); + + Subparser createDatasetParser = subparsers.addParser("create_dataset"); + createDatasetParser.addArgument("datasetName"); + createDatasetParser.addArgument("source"); + createDatasetParser.addArgument("target"); + + Subparser listDatasetParser = subparsers.addParser("list_datasets"); + listDatasetParser.addArgument("filter").nargs("?").setDefault("translation_dataset_metadata:*"); + + Subparser getDatasetParser = subparsers.addParser("get_dataset"); + getDatasetParser.addArgument("datasetId"); + + Subparser importDataParser = subparsers.addParser("import_data"); + importDataParser.addArgument("datasetId"); + importDataParser.addArgument("path"); + + Subparser deleteDatasetParser = subparsers.addParser("delete_dataset"); + deleteDatasetParser.addArgument("datasetId"); + + String projectId = System.getenv("PROJECT_ID"); + String computeRegion = System.getenv("REGION_NAME"); + + Namespace ns = null; + try { + ns = parser.parseArgs(args); + if (ns.get("command").equals("create_dataset")) { + createDataset( + projectId, + computeRegion, + ns.getString("datasetName"), + ns.getString("source"), + ns.getString("target")); + } + if (ns.get("command").equals("list_datasets")) { + listDatasets(projectId, computeRegion, ns.getString("filter")); + } + if (ns.get("command").equals("get_dataset")) { + getDataset(projectId, computeRegion, ns.getString("datasetId")); + } + if (ns.get("command").equals("import_data")) { + importData(projectId, computeRegion, ns.getString("datasetId"), ns.getString("path")); + } + if (ns.get("command").equals("delete_dataset")) { + deleteDataset(projectId, computeRegion, ns.getString("datasetId")); + } + } catch (ArgumentParserException e) { + parser.handleError(e); + } + } +} diff --git a/translate/automl/src/main/java/com/google/cloud/translate/automl/ModelApi.java b/translate/automl/src/main/java/com/google/cloud/translate/automl/ModelApi.java new file mode 100644 index 00000000000..c720c3d5f40 --- /dev/null +++ b/translate/automl/src/main/java/com/google/cloud/translate/automl/ModelApi.java @@ -0,0 +1,341 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.translate.automl; + +// Imports the Google Cloud client library +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.automl.v1beta1.AutoMlClient; +import com.google.cloud.automl.v1beta1.ListModelEvaluationsRequest; +import com.google.cloud.automl.v1beta1.ListModelsRequest; +import com.google.cloud.automl.v1beta1.LocationName; +import com.google.cloud.automl.v1beta1.Model; +import com.google.cloud.automl.v1beta1.ModelEvaluation; +import com.google.cloud.automl.v1beta1.ModelEvaluationName; +import com.google.cloud.automl.v1beta1.ModelName; +import com.google.cloud.automl.v1beta1.OperationMetadata; +import com.google.cloud.automl.v1beta1.TranslationModelMetadata; +import com.google.longrunning.Operation; +import com.google.protobuf.Empty; + +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; + +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import net.sourceforge.argparse4j.inf.Subparsers; + +/** + * Google Cloud AutoML Translate API sample application. Example usage: mvn package exec:java + * -Dexec.mainClass ='com.example.translate.ModelApi' -Dexec.args='create_model [datasetId] + * test_model' + */ +public class ModelApi { + + // [START automl_translation_create_model] + /** + * Demonstrates using the AutoML client to create a model. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param dataSetId the Id of the dataset to which model is created. + * @param modelName the Name of the model. + * @throws Exception on AutoML Client errors + */ + public static void createModel( + String projectId, String computeRegion, String dataSetId, String modelName) throws Exception { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Set model metadata. + TranslationModelMetadata translationModelMetadata = + TranslationModelMetadata.newBuilder().setBaseModel("").build(); + + // Set model name, dataset and metadata. + Model myModel = + Model.newBuilder() + .setDisplayName(modelName) + .setDatasetId(dataSetId) + .setTranslationModelMetadata(translationModelMetadata) + .build(); + + // Create a model with the model metadata in the region. + OperationFuture response = + client.createModelAsync(projectLocation, myModel); + + System.out.println( + String.format("Training operation name: %s", response.getInitialFuture().get().getName())); + System.out.println("Training started..."); + } + // [END automl_translation_create_model] + + // [START automl_translation_list_models] + /** + * Demonstrates using the AutoML client to list all models. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param filter the filter expression. + * @throws IOException on Input/Output errors. + */ + public static void listModels(String projectId, String computeRegion, String filter) + throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Create list models request. + ListModelsRequest listModlesRequest = + ListModelsRequest.newBuilder() + .setParent(projectLocation.toString()) + .setFilter(filter) + .build(); + + // List all the models available in the region by applying filter. + System.out.println("List of models:"); + for (Model model : client.listModels(listModlesRequest).iterateAll()) { + // Display the model information. + System.out.println(String.format("Model name: %s", model.getName())); + System.out.println( + String.format( + "Model id: %s", model.getName().split("/")[model.getName().split("/").length - 1])); + System.out.println(String.format("Model display name: %s", model.getDisplayName())); + System.out.println("Model create time:"); + System.out.println(String.format("\tseconds: %s", model.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", model.getCreateTime().getNanos())); + System.out.println(String.format("Model deployment state: %s", model.getDeploymentState())); + } + } + // [END automl_translation_list_models] + + // [START automl_translation_get_model] + /** + * Demonstrates using the AutoML client to get model details. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model. + * @throws IOException on Input/Output errors. + */ + public static void getModel(String projectId, String computeRegion, String modelId) + throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + + // Get complete detail of the model. + Model model = client.getModel(modelFullId); + + // Display the model information. + System.out.println(String.format("Model name: %s", model.getName())); + System.out.println( + String.format( + "Model id: %s", model.getName().split("/")[model.getName().split("/").length - 1])); + System.out.println(String.format("Model display name: %s", model.getDisplayName())); + System.out.println("Model create time:"); + System.out.println(String.format("\tseconds: %s", model.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", model.getCreateTime().getNanos())); + System.out.println(String.format("Model deployment state: %s", model.getDeploymentState())); + } + // [END automl_translation_get_model] + + // [START automl_translation_list_model_evaluations] + /** + * Demonstrates using the AutoML client to list model evaluations. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model. + * @param filter the filter expression. + * @throws IOException on Input/Output errors. + */ + public static void listModelEvaluations( + String projectId, String computeRegion, String modelId, String filter) throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + + // Create list model evaluations request + ListModelEvaluationsRequest modelEvaluationsrequest = + ListModelEvaluationsRequest.newBuilder() + .setParent(modelFullId.toString()) + .setFilter(filter) + .build(); + + // List all the model evaluations in the model by applying filter. + System.out.println("List of model evaluations:"); + for (ModelEvaluation element : + client.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + System.out.println(element); + } + } + // [END automl_translation_list_model_evaluations] + + // [START automl_translation_get_model_evaluation] + /** + * Demonstrates using the AutoML client to get model evaluations. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model. + * @param modelEvaluationId the Id of your model evaluation. + * @throws IOException on Input/Output errors. + */ + public static void getModelEvaluation( + String projectId, String computeRegion, String modelId, String modelEvaluationId) + throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the model evaluation. + ModelEvaluationName modelEvaluationFullId = + ModelEvaluationName.of(projectId, computeRegion, modelId, modelEvaluationId); + + // Get complete detail of the model evaluation. + ModelEvaluation response = client.getModelEvaluation(modelEvaluationFullId); + + System.out.println(response); + } + // [END automl_translation_get_model_evaluation] + + // [START automl_translation_delete_model] + /** + * Demonstrates using the AutoML client to delete a model. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model. + * @throws Exception on AutoML Client errors + */ + public static void deleteModel(String projectId, String computeRegion, String modelId) + throws InterruptedException, ExecutionException, IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + + // Delete a model. + Empty response = client.deleteModelAsync(modelFullId).get(); + + System.out.println("Model deletion started..."); + } + // [END automl_translation_delete_model] + + // [START automl_translation_get_operation_status] + /** + * Demonstrates using the AutoML client to get operation status. + * + * @param operationFullId Full name of a operation. For example, the name of your operation is + * projects/[projectId]/locations/us-central1/operations/[operationId]. + * @throws IOException on Input/Output errors. + */ + private static void getOperationStatus(String operationFullId) throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the latest state of a long-running operation. + Operation response = client.getOperationsClient().getOperation(operationFullId); + + System.out.println(String.format("Operation status: %s", response)); + } + // [END automl_translation_get_operation_status] + + public static void main(String[] args) throws Exception { + ModelApi modelApi = new ModelApi(); + modelApi.argsHelper(args, System.out); + } + + public static void argsHelper(String[] args, PrintStream out) throws Exception { + + ArgumentParser parser = + ArgumentParsers.newFor("ModelApi") + .build() + .defaultHelp(true) + .description("Model API operations"); + Subparsers subparsers = parser.addSubparsers().dest("command"); + + Subparser createModelParser = subparsers.addParser("create_model"); + createModelParser.addArgument("datasetId"); + createModelParser.addArgument("modelName"); + + Subparser listModelParser = subparsers.addParser("list_models"); + listModelParser.addArgument("filter").nargs("?").setDefault(""); + + Subparser getModelParser = subparsers.addParser("get_model"); + getModelParser.addArgument("modelId"); + + Subparser listModelEvaluationsParser = subparsers.addParser("list_model_evaluations"); + listModelEvaluationsParser.addArgument("modelId"); + listModelEvaluationsParser.addArgument("filter").nargs("?").setDefault(""); + + Subparser getModelEvaluationParser = subparsers.addParser("get_model_evaluation"); + getModelEvaluationParser.addArgument("modelId"); + getModelEvaluationParser.addArgument("modelEvaluationId"); + + Subparser deleteModelParser = subparsers.addParser("delete_model"); + deleteModelParser.addArgument("modelId"); + + Subparser getOperationStatusParser = subparsers.addParser("get_operation_status"); + getOperationStatusParser.addArgument("operationFullId"); + + String projectId = System.getenv("PROJECT_ID"); + String computeRegion = System.getenv("REGION_NAME"); + + Namespace ns = null; + try { + ns = parser.parseArgs(args); + if (ns.get("command").equals("create_model")) { + createModel(projectId, computeRegion, ns.getString("datasetId"), ns.getString("modelName")); + } + if (ns.get("command").equals("list_models")) { + listModels(projectId, computeRegion, ns.getString("filter")); + } + if (ns.get("command").equals("get_model")) { + getModel(projectId, computeRegion, ns.getString("modelId")); + } + if (ns.get("command").equals("list_model_evaluations")) { + listModelEvaluations( + projectId, computeRegion, ns.getString("modelId"), ns.getString("filter")); + } + if (ns.get("command").equals("get_model_evaluation")) { + getModelEvaluation( + projectId, computeRegion, ns.getString("modelId"), ns.getString("modelEvaluationId")); + } + if (ns.get("command").equals("delete_model")) { + deleteModel(projectId, computeRegion, ns.getString("modelId")); + } + if (ns.get("command").equals("get_operation_status")) { + getOperationStatus(ns.getString("operationFullId")); + } + } catch (ArgumentParserException e) { + parser.handleError(e); + } + } +} diff --git a/translate/automl/src/main/java/com/google/cloud/translate/automl/PredictionApi.java b/translate/automl/src/main/java/com/google/cloud/translate/automl/PredictionApi.java new file mode 100644 index 00000000000..d3b1170897f --- /dev/null +++ b/translate/automl/src/main/java/com/google/cloud/translate/automl/PredictionApi.java @@ -0,0 +1,140 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This application demonstrates how to perform basic operations on prediction + * with the Google AutoML Vision API. + * + * For more information, the documentation at + * https://cloud.google.com/vision/automl/docs. + */ + +package com.google.cloud.translate.automl; + +// Imports the Google Cloud client library +import com.google.cloud.automl.v1beta1.ExamplePayload; +import com.google.cloud.automl.v1beta1.ModelName; +import com.google.cloud.automl.v1beta1.PredictResponse; +import com.google.cloud.automl.v1beta1.PredictionServiceClient; + +import com.google.cloud.automl.v1beta1.TextSnippet; +import java.io.IOException; +import java.io.PrintStream; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import net.sourceforge.argparse4j.inf.Subparsers; + +/** + * Google Cloud AutoML Translate API sample application. Example usage: mvn package exec:java + * -Dexec.mainClass ='com.google.cloud.vision.samples.automl.PredictionApi' -Dexec.args='predict + * [modelId] [path-to-image] [scoreThreshold]' + */ +public class PredictionApi { + + // [START automl_translation_predict] + + /** + * Demonstrates using the AutoML client to predict an image. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model which will be used for text classification. + * @param filePath the Local text file path of the content to be classified. + * @param translationAllowFallback set to true to use a Google translation. + * @throws IOException on Input/Output errors. + */ + public static void predict( + String projectId, + String computeRegion, + String modelId, + String filePath, + boolean translationAllowFallback) + throws IOException { + // Instantiate client for prediction service. + PredictionServiceClient predictionClient = PredictionServiceClient.create(); + + // Get the full path of the model. + ModelName name = ModelName.of(projectId, computeRegion, modelId); + + // Read the file content for translation. + String content = new String(Files.readAllBytes(Paths.get(filePath))); + + TextSnippet textSnippet = TextSnippet.newBuilder().setContent(content).build(); + + // Set the payload by giving the content of the file. + ExamplePayload payload = ExamplePayload.newBuilder().setTextSnippet(textSnippet).build(); + + // Additional parameters that can be provided for prediction + Map params = new HashMap<>(); + if (translationAllowFallback) { + params.put("translation_allow_fallback", "True");//Allow Google Translation Model + } + + PredictResponse response = predictionClient.predict(name, payload, params); + TextSnippet translatedContent = response.getPayload(0).getTranslation().getTranslatedContent(); + + System.out.println(String.format("Translated Content: %s", translatedContent.getContent())); + } + // [END automl_translation_predict] + + public static void main(String[] args) throws IOException { + PredictionApi predictApi = new PredictionApi(); + predictApi.argsHelper(args, System.out); + } + + public static void argsHelper(String[] args, PrintStream out) throws IOException { + ArgumentParser parser = ArgumentParsers.newFor("PredictionApi") + .build() + .defaultHelp(true) + .description("Prediction API Operation"); + Subparsers subparsers = parser.addSubparsers().dest("command"); + + Subparser predictParser = subparsers.addParser("predict"); + predictParser.addArgument("modelId"); + predictParser.addArgument("filePath"); + predictParser + .addArgument("translationAllowFallback") + .nargs("?") + .type(Boolean.class) + .setDefault(Boolean.FALSE); + + String projectId = System.getenv("PROJECT_ID"); + String computeRegion = System.getenv("REGION_NAME"); + + Namespace ns = null; + try { + ns = parser.parseArgs(args); + if (ns.get("command").equals("predict")) { + predict( + projectId, + computeRegion, + ns.getString("modelId"), + ns.getString("filePath"), + ns.getBoolean("translationAllowFallback")); + } + } catch (ArgumentParserException e) { + parser.handleError(e); + } + } +} diff --git a/translate/automl/src/test/java/com/google/cloud/translate/automl/DatasetApiIT.java b/translate/automl/src/test/java/com/google/cloud/translate/automl/DatasetApiIT.java new file mode 100644 index 00000000000..2f47e55968a --- /dev/null +++ b/translate/automl/src/test/java/com/google/cloud/translate/automl/DatasetApiIT.java @@ -0,0 +1,106 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.translate.automl; + +import static com.google.common.truth.Truth.assertThat; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for Automl translation "Dataset API" sample. */ +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class DatasetApiIT { + + private static final String PROJECT_ID = "java-docs-samples-testing"; + private static final String BUCKET = PROJECT_ID + "-vcm"; + private static final String COMPUTE_REGION = "us-central1"; + private static final String DATASET_NAME = "test_translate_dataset"; + private ByteArrayOutputStream bout; + private PrintStream out; + private DatasetApi app; + private String datasetId; + private String getdatasetId = "3946265060617537378"; + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() { + System.setOut(null); + } + + @Test + public void testCreateImportDeleteDataset() throws Exception { + // Act + DatasetApi.createDataset(PROJECT_ID, COMPUTE_REGION, DATASET_NAME, "en", "ja"); + + // Assert + String got = bout.toString(); + datasetId = + bout.toString() + .split("\n")[0] + .split("/")[(bout.toString().split("\n")[0]).split("/").length - 1]; + assertThat(got).contains("Dataset id:"); + + // Act + DatasetApi.importData(PROJECT_ID, COMPUTE_REGION, datasetId, "gs://" + BUCKET + "/en-ja.csv"); + + // Assert + got = bout.toString(); + assertThat(got).contains("Dataset id:"); + + // Act + DatasetApi.deleteDataset(PROJECT_ID, COMPUTE_REGION, datasetId); + + // Assert + got = bout.toString(); + assertThat(got).contains("Dataset deleted."); + } + + @Test + public void testListDataset() throws Exception { + // Act + DatasetApi.listDatasets(PROJECT_ID, COMPUTE_REGION, "translation_dataset_metadata:*"); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Dataset id:"); + } + + @Test + public void testGetDataset() throws Exception { + + // Act + DatasetApi.getDataset(PROJECT_ID, COMPUTE_REGION, getdatasetId); + + // Assert + String got = bout.toString(); + + assertThat(got).contains("Dataset id:"); + } +} diff --git a/translate/automl/src/test/java/com/google/cloud/translate/automl/ModelApiIT.java b/translate/automl/src/test/java/com/google/cloud/translate/automl/ModelApiIT.java new file mode 100644 index 00000000000..0ebfeed339b --- /dev/null +++ b/translate/automl/src/test/java/com/google/cloud/translate/automl/ModelApiIT.java @@ -0,0 +1,89 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.translate.automl; + +import static com.google.common.truth.Truth.assertThat; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for translation "Model API" sample. */ +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class ModelApiIT { + private static final String PROJECT_ID = "java-docs-samples-testing"; + private static final String COMPUTE_REGION = "us-central1"; + private ByteArrayOutputStream bout; + private PrintStream out; + private ModelApi app; + private String modelId; + private String modelEvaluationId; + + @Before + public void setUp() { + + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() { + System.setOut(null); + } + + @Test + public void testModelApi() throws Exception { + // Act + ModelApi.listModels(PROJECT_ID, COMPUTE_REGION, ""); + + // Assert + String got = bout.toString(); + modelId = got.split("\n")[1].split("/")[got.split("\n")[1].split("/").length - 1]; + assertThat(got).contains("Model id:"); + + // Act + ModelApi.getModel(PROJECT_ID, COMPUTE_REGION, modelId); + + // Assert + got = bout.toString(); + assertThat(got).contains("Model name:"); + + // Act + ModelApi.listModelEvaluations(PROJECT_ID, COMPUTE_REGION, modelId, ""); + + // Assert + got = bout.toString(); + modelEvaluationId = got.split("List of model evaluations:")[1].split("\"")[1].split("/")[7]; + assertThat(got).contains("name:"); + + // Act + ModelApi.getModelEvaluation(PROJECT_ID, COMPUTE_REGION, modelId, modelEvaluationId); + + // Assert + got = bout.toString(); + assertThat(got).contains("name:"); + + } +} + diff --git a/translate/automl/src/test/java/com/google/cloud/translate/automl/PredictionApiIT.java b/translate/automl/src/test/java/com/google/cloud/translate/automl/PredictionApiIT.java new file mode 100644 index 00000000000..962a0fd0802 --- /dev/null +++ b/translate/automl/src/test/java/com/google/cloud/translate/automl/PredictionApiIT.java @@ -0,0 +1,64 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.translate.automl; + +import static com.google.common.truth.Truth.assertThat; +import static java.lang.Boolean.FALSE; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for translation "PredictionAPI" sample. */ +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class PredictionApiIT { + private static final String COMPUTE_REGION = "us-central1"; + private static final String PROJECT_ID = "java-docs-samples-testing"; + private static final String modelId = "2188848820815848149"; + private static final String filePath = "./resources/input.txt"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PredictionApi app; + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() { + System.setOut(null); + } + + @Test + public void testPredict() throws Exception { + // Act + PredictionApi.predict(PROJECT_ID, COMPUTE_REGION, modelId, filePath,FALSE); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Translated Content"); + } +} diff --git a/vision/automl/README.md b/vision/automl/README.md new file mode 100644 index 00000000000..24a73c48ece --- /dev/null +++ b/vision/automl/README.md @@ -0,0 +1,87 @@ +# AutoML Sample + + +Open in Cloud Shell + +[Google Cloud Vision API][vision] provides feature detection for images. +This API is part of the larger collection of Cloud Machine Learning APIs. + +This sample Java application demonstrates how to access the Cloud Vision API +using the [Google Cloud Client Library for Java][google-cloud-java]. + +[vision]: https://cloud.google.com/vision/docs/ +[google-cloud-java]: https://github.com/GoogleCloudPlatform/google-cloud-java + +## Set the environment variables + +PROJECT_ID = [Id of the project] +REGION_NAME = [Region name] +## Build the sample + +Install [Maven](http://maven.apache.org/). + +Build your project with: + +``` +mvn clean package +``` + +### Dataset API + +#### Create a new dataset +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.DatasetApi" -Dexec.args="create_dataset test_dataset" +``` + +#### List datasets +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.DatasetApi" -Dexec.args="list_datasets" +``` + +#### Get dataset +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.DatasetApi" -Dexec.args="get_dataset [dataset-id]" +``` + +#### Import data +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.DatasetApi" -Dexec.args="import_data gs://java-docs-samples-testing/flower_traindata.csv" +``` + +### Model API + +#### Create Model +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.ModelApi" -Dexec.args="create_model test_model" +``` + +#### List Models +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.ModelApi" -Dexec.args="list_models" +``` + +#### Get Model +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.ModelApi" -Dexec.args="get_model [model-id]" +``` + +#### List Model Evaluations +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.ModelApi" -Dexec.args="list_model_evaluation [model-id]" +``` + +#### Get Model Evaluation +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.ModelApi" -Dexec.args="get_model_evaluation [model-id] [model-evaluation-id]" +``` + +#### Delete Model +``` +mvn exec:java-Dexec.mainClass="com.google.cloud.vision.samples.automl.ModeltApi" -Dexec.args="delete_model [model-id]" +``` +### Predict API + +``` +mvn exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.PredictApi" -Dexec.args="predict [model-id] ./resources/dandelion.jpg 0.7" +``` + diff --git a/vision/automl/pom.xml b/vision/automl/pom.xml new file mode 100644 index 00000000000..5b5dd838624 --- /dev/null +++ b/vision/automl/pom.xml @@ -0,0 +1,154 @@ + + + 4.0.0 + com.example.vision + vision-automl + jar + + + + com.google.cloud.samples + shared-configuration + 1.0.9 + + + + 1.8 + 1.8 + UTF-8 + + + + + + com.google.cloud + google-cloud-automl + 0.55.0-beta + + + net.sourceforge.argparse4j + argparse4j + 0.8.1 + + + + + + junit + junit + 4.12 + test + + + + com.google.truth + truth + 0.41 + test + + + + + + DatasetApi + + + DatasetApi + + + + + + org.codehaus.mojo + exec-maven-plugin + 1.6.0 + + + + java + + + + + com.google.cloud.vision.samples.automl.DatasetApi + false + + + + + + + ModelApi + + + ModelApi + + + + + + org.codehaus.mojo + exec-maven-plugin + 1.6.0 + + + + java + + + + + com.google.cloud.vision.samples.automl.ModelApi + false + + + + + + + PredictApi + + + PredictApi + + + + + + org.codehaus.mojo + exec-maven-plugin + 1.6.0 + + + + java + + + + + com.google.cloud.vision.samples.automl.PredictApi + false + + + + + + + diff --git a/vision/automl/resources/dandelion.jpg b/vision/automl/resources/dandelion.jpg new file mode 100644 index 00000000000..326e4c1bf53 Binary files /dev/null and b/vision/automl/resources/dandelion.jpg differ diff --git a/vision/automl/src/main/java/com/google/cloud/vision/samples/automl/DatasetApi.java b/vision/automl/src/main/java/com/google/cloud/vision/samples/automl/DatasetApi.java new file mode 100644 index 00000000000..5b07e542749 --- /dev/null +++ b/vision/automl/src/main/java/com/google/cloud/vision/samples/automl/DatasetApi.java @@ -0,0 +1,347 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vision.samples.automl; + +// Imports the Google Cloud client library +import com.google.cloud.automl.v1beta1.AutoMlClient; +import com.google.cloud.automl.v1beta1.ClassificationProto.ClassificationType; +import com.google.cloud.automl.v1beta1.Dataset; +import com.google.cloud.automl.v1beta1.DatasetName; +import com.google.cloud.automl.v1beta1.GcsDestination; +import com.google.cloud.automl.v1beta1.GcsSource; +import com.google.cloud.automl.v1beta1.ImageClassificationDatasetMetadata; +import com.google.cloud.automl.v1beta1.InputConfig; +import com.google.cloud.automl.v1beta1.ListDatasetsRequest; +import com.google.cloud.automl.v1beta1.LocationName; +import com.google.cloud.automl.v1beta1.OutputConfig; +import com.google.protobuf.Empty; + +import java.io.IOException; +import java.io.PrintStream; + +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import net.sourceforge.argparse4j.inf.Subparsers; + +/** + * Google Cloud AutoML Vision API sample application. Example usage: mvn package exec:java + * -Dexec.mainClass ='com.google.cloud.vision.samples.automl.DatasetAPI' -Dexec.args='create_dataset + * test_dataset' + */ +public class DatasetApi { + + // [START automl_vision_create_dataset] + /** + * Demonstrates using the AutoML client to create a dataset + * + * @param projectId the Google Cloud Project ID. + * @param computeRegion the Region name. (e.g., "us-central1") + * @param datasetName the name of the dataset to be created. + * @param multiLabel the type of classification problem. Set to FALSE by default. False - + * MULTICLASS , True - MULTILABEL + * @throws IOException on Input/Output errors. + */ + public static void createDataset( + String projectId, String computeRegion, String datasetName, Boolean multiLabel) + throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Classification type assigned based on multiLabel value. + ClassificationType classificationType = + multiLabel ? ClassificationType.MULTILABEL : ClassificationType.MULTICLASS; + + // Specify the image classification type for the dataset. + ImageClassificationDatasetMetadata imageClassificationDatasetMetadata = + ImageClassificationDatasetMetadata.newBuilder() + .setClassificationType(classificationType) + .build(); + + // Set dataset with dataset name and set the dataset metadata. + Dataset myDataset = + Dataset.newBuilder() + .setDisplayName(datasetName) + .setImageClassificationDatasetMetadata(imageClassificationDatasetMetadata) + .build(); + + // Create dataset with the dataset metadata in the region. + Dataset dataset = client.createDataset(projectLocation, myDataset); + + // Display the dataset information + System.out.println(String.format("Dataset name: %s", dataset.getName())); + System.out.println( + String.format( + "Dataset id: %s", + dataset.getName().split("/")[dataset.getName().split("/").length - 1])); + System.out.println(String.format("Dataset display name: %s", dataset.getDisplayName())); + System.out.println("Image classification dataset specification:"); + System.out.print(String.format("\t%s", dataset.getImageClassificationDatasetMetadata())); + System.out.println(String.format("Dataset example count: %d", dataset.getExampleCount())); + System.out.println("Dataset create time:"); + System.out.println(String.format("\tseconds: %s", dataset.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos())); + } + // [END automl_vision_create_dataset] + + // [START automl_vision_list_datasets] + /** + * Demonstrates using the AutoML client to list all datasets. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param filter the Filter expression. + * @throws IOException on Input/Output errors. + */ + public static void listDatasets(String projectId, String computeRegion, String filter) + throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Build the List datasets request + ListDatasetsRequest request = + ListDatasetsRequest.newBuilder() + .setParent(projectLocation.toString()) + .setFilter(filter) + .build(); + + // List all the datasets available in the region by applying the filter. + System.out.print("List of datasets:"); + for (Dataset dataset : client.listDatasets(request).iterateAll()) { + // Display the dataset information + System.out.println(String.format("\nDataset name: %s", dataset.getName())); + System.out.println( + String.format( + "Dataset id: %s", + dataset.getName().split("/")[dataset.getName().split("/").length - 1])); + System.out.println(String.format("Dataset display name: %s", dataset.getDisplayName())); + System.out.println("Image classification dataset specification:"); + System.out.print(String.format("\t%s", dataset.getImageClassificationDatasetMetadata())); + System.out.println(String.format("Dataset example count: %d", dataset.getExampleCount())); + System.out.println("Dataset create time:"); + System.out.println(String.format("\tseconds: %s", dataset.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos())); + } + } + // [END automl_vision_list_datasets] + + // [START automl_vision_get_dataset] + /** + * Demonstrates using the AutoML client to get a dataset by ID. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param datasetId the Id of the dataset. + * @throws IOException on Input/Output errors. + */ + public static void getDataset(String projectId, String computeRegion, String datasetId) + throws IOException { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the complete path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + + // Get all the information about a given dataset. + Dataset dataset = client.getDataset(datasetFullId); + + // Display the dataset information. + System.out.println(String.format("Dataset name: %s", dataset.getName())); + System.out.println( + String.format( + "Dataset id: %s", + dataset.getName().split("/")[dataset.getName().split("/").length - 1])); + System.out.println(String.format("Dataset display name: %s", dataset.getDisplayName())); + System.out.println("Image classification dataset specification:"); + System.out.print(String.format("\t%s", dataset.getImageClassificationDatasetMetadata())); + System.out.println(String.format("Dataset example count: %d", dataset.getExampleCount())); + System.out.println("Dataset create time:"); + System.out.println(String.format("\tseconds: %s", dataset.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos())); + } + // [END automl_vision_get_dataset] + + // [START automl_vision_import_data] + /** + * Demonstrates using the AutoML client to import labeled images. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param datasetId the Id of the dataset to which the training data will be imported. + * @param path the Google Cloud Storage URIs. Target files must be in AutoML vision CSV format. + * @throws Exception on AutoML Client errors + */ + public static void importData( + String projectId, String computeRegion, String datasetId, String path) throws Exception { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the complete path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + + GcsSource.Builder gcsSource = GcsSource.newBuilder(); + + // Get multiple training data files to be imported + String[] inputUris = path.split(","); + for (String inputUri : inputUris) { + gcsSource.addInputUris(inputUri); + } + + // Import data from the input URI + InputConfig inputConfig = InputConfig.newBuilder().setGcsSource(gcsSource).build(); + System.out.println("Processing import..."); + Empty response = client.importDataAsync(datasetFullId.toString(), inputConfig).get(); + System.out.println(String.format("Dataset imported. %s", response)); + } + // [END automl_vision_import_data] + + // [START automl_vision_export_data] + /** + * Demonstrates using the AutoML client to export a dataset to a Google Cloud Storage bucket. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param datasetId the Id of the dataset. + * @param gcsUri the Destination URI (Google Cloud Storage) + * @throws Exception on AutoML Client errors + */ + public static void exportData( + String projectId, String computeRegion, String datasetId, String gcsUri) throws Exception { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the complete path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + + // Set the output URI + GcsDestination gcsDestination = GcsDestination.newBuilder().setOutputUriPrefix(gcsUri).build(); + + // Export the dataset to the output URI. + OutputConfig outputConfig = OutputConfig.newBuilder().setGcsDestination(gcsDestination).build(); + System.out.println("Processing export..."); + + Empty response = client.exportDataAsync(datasetFullId, outputConfig).get(); + System.out.println(String.format("Dataset exported. %s", response)); + } + // [END automl_vision_export_data] + + // [START automl_vision_delete_dataset] + /** + * Delete a dataset. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param datasetId the Id of the dataset. + * @throws Exception on AutoML Client errors + */ + public static void deleteDataset(String projectId, String computeRegion, String datasetId) + throws Exception { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // Get the complete path of the dataset. + DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId); + + // Delete a dataset. + Empty response = client.deleteDatasetAsync(datasetFullId).get(); + + System.out.println(String.format("Dataset deleted. %s", response)); + } + // [END automl_vision_delete_dataset] + + public static void main(String[] args) throws Exception { + DatasetApi datasetApi = new DatasetApi(); + datasetApi.argsHelper(args, System.out); + } + + public static void argsHelper(String[] args, PrintStream out) throws Exception { + ArgumentParser parser = + ArgumentParsers.newFor("DatasetApi") + .build() + .defaultHelp(true) + .description("Dataset API operations."); + Subparsers subparsers = parser.addSubparsers().dest("command"); + + Subparser createDatasetParser = subparsers.addParser("create_dataset"); + createDatasetParser.addArgument("datasetName"); + createDatasetParser + .addArgument("multiLabel") + .nargs("?") + .type(Boolean.class) + .choices(Boolean.FALSE, Boolean.TRUE) + .setDefault(Boolean.FALSE); + + Subparser listDatasetsParser = subparsers.addParser("list_datasets"); + listDatasetsParser + .addArgument("filter") + .nargs("?") + .setDefault("imageClassificationDatasetMetadata:*"); + + Subparser getDatasetParser = subparsers.addParser("get_dataset"); + getDatasetParser.addArgument("datasetId"); + + Subparser importDataParser = subparsers.addParser("import_data"); + importDataParser.addArgument("datasetId"); + importDataParser.addArgument("path"); + + Subparser exportDataParser = subparsers.addParser("export_data"); + exportDataParser.addArgument("datasetId"); + exportDataParser.addArgument("gcsUri"); + + Subparser deleteDatasetParser = subparsers.addParser("delete_dataset"); + deleteDatasetParser.addArgument("datasetId"); + + String projectId = System.getenv("PROJECT_ID"); + String computeRegion = System.getenv("REGION_NAME"); + + Namespace ns = null; + try { + ns = parser.parseArgs(args); + + if (ns.get("command").equals("create_dataset")) { + createDataset( + projectId, computeRegion, ns.getString("datasetName"), ns.getBoolean("multiLabel")); + } + if (ns.get("command").equals("list_datasets")) { + listDatasets(projectId, computeRegion, ns.getString("filter")); + } + if (ns.get("command").equals("get_dataset")) { + getDataset(projectId, computeRegion, ns.getString("datasetId")); + } + if (ns.get("command").equals("import_data")) { + importData(projectId, computeRegion, ns.getString("datasetId"), ns.getString("path")); + } + if (ns.get("command").equals("export_data")) { + exportData(projectId, computeRegion, ns.getString("datasetId"), ns.getString("gcsUri")); + } + if (ns.get("command").equals("delete_dataset")) { + deleteDataset(projectId, computeRegion, ns.getString("datasetId")); + } + + } catch (ArgumentParserException e) { + parser.handleError(e); + } + } +} diff --git a/vision/automl/src/main/java/com/google/cloud/vision/samples/automl/ModelApi.java b/vision/automl/src/main/java/com/google/cloud/vision/samples/automl/ModelApi.java new file mode 100644 index 00000000000..a0f6c7b002c --- /dev/null +++ b/vision/automl/src/main/java/com/google/cloud/vision/samples/automl/ModelApi.java @@ -0,0 +1,456 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vision.samples.automl; + +// Imports the Google Cloud client library +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.automl.v1beta1.AutoMlClient; +import com.google.cloud.automl.v1beta1.ClassificationProto.ClassificationEvaluationMetrics; +import com.google.cloud.automl.v1beta1.ClassificationProto.ClassificationEvaluationMetrics.ConfidenceMetricsEntry; +import com.google.cloud.automl.v1beta1.ImageClassificationModelMetadata; +import com.google.cloud.automl.v1beta1.ListModelEvaluationsRequest; +import com.google.cloud.automl.v1beta1.ListModelsRequest; +import com.google.cloud.automl.v1beta1.LocationName; +import com.google.cloud.automl.v1beta1.Model; +import com.google.cloud.automl.v1beta1.ModelEvaluation; +import com.google.cloud.automl.v1beta1.ModelEvaluationName; +import com.google.cloud.automl.v1beta1.ModelName; +import com.google.cloud.automl.v1beta1.OperationMetadata; + +import com.google.longrunning.Operation; +import com.google.protobuf.Empty; + +import java.io.IOException; +import java.io.PrintStream; +import java.util.List; + +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import net.sourceforge.argparse4j.inf.Subparsers; + +/** + * Google Cloud AutoML Vision API sample application. Example usage: mvn package exec:java + * -Dexec.mainClass ='com.google.cloud.vision.samples.automl.ModelApi' -Dexec.args='create_model + * [datasetId] test_model' + */ +public class ModelApi { + + // [START automl_vision_create_model] + /** + * Demonstrates using the AutoML client to create a model. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param dataSetId the Id of the dataset to which model is created. + * @param modelName the Name of the model. + * @param trainBudget the Budget for training the model. + * @throws Exception on AutoML Client errors + */ + public static void createModel( + String projectId, + String computeRegion, + String dataSetId, + String modelName, + String trainBudget) + throws Exception { + // Instantiates a client + AutoMlClient client = AutoMlClient.create(); + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Set model metadata. + ImageClassificationModelMetadata imageClassificationModelMetadata = + Long.valueOf(trainBudget) == 0 + ? ImageClassificationModelMetadata.newBuilder().build() + : ImageClassificationModelMetadata.newBuilder() + .setTrainBudget(Long.valueOf(trainBudget)) + .build(); + + // Set model name and model metadata for the image dataset. + Model myModel = + Model.newBuilder() + .setDisplayName(modelName) + .setDatasetId(dataSetId) + .setImageClassificationModelMetadata(imageClassificationModelMetadata) + .build(); + + // Create a model with the model metadata in the region. + OperationFuture response = + client.createModelAsync(projectLocation, myModel); + + System.out.println( + String.format("Training operation name: %s", response.getInitialFuture().get().getName())); + System.out.println("Training started..."); + } + // [END automl_vision_create_model] + + // [START automl_vision_get_operation_status] + /** + * Demonstrates using the AutoML client to get operation status. + * + * @param operationFullId the complete name of a operation. For example, the name of your + * operation is projects/[projectId]/locations/us-central1/operations/[operationId]. + * @throws IOException on Input/Output errors. + */ + public static void getOperationStatus(String operationFullId) throws IOException { + AutoMlClient client = AutoMlClient.create(); + + // Get the latest state of a long-running operation. + Operation response = client.getOperationsClient().getOperation(operationFullId); + + System.out.println(String.format("Operation status: %s", response)); + } + // [END automl_vision_get_operation_status] + + // [START automl_vision_list_models] + /** + * Demonstrates using the AutoML client to list all models. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param filter - Filter expression. + * @throws IOException on Input/Output errors. + */ + public static void listModels(String projectId, String computeRegion, String filter) + throws IOException { + AutoMlClient client = AutoMlClient.create(); + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Create list models request + ListModelsRequest listModelsRequest = + ListModelsRequest.newBuilder() + .setParent(projectLocation.toString()) + .setFilter(filter) + .build(); + + System.out.println("List of models:"); + for (Model model : client.listModels(listModelsRequest).iterateAll()) { + // Display the model information. + System.out.println(String.format("Model name: %s", model.getName())); + System.out.println( + String.format( + "Model id: %s", model.getName().split("/")[model.getName().split("/").length - 1])); + System.out.println(String.format("Model display name: %s", model.getDisplayName())); + System.out.println("Image classification model metadata:"); + System.out.println( + "Tranning budget: " + model.getImageClassificationModelMetadata().getTrainBudget()); + System.out.println( + "Tranning cost: " + model.getImageClassificationModelMetadata().getTrainCost()); + System.out.println( + String.format( + "Stop reason: %s", model.getImageClassificationModelMetadata().getStopReason())); + System.out.println( + String.format( + "Base model id: %s", model.getImageClassificationModelMetadata().getBaseModelId())); + System.out.println("Model create time:"); + System.out.println(String.format("\tseconds: %s", model.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", model.getCreateTime().getNanos())); + System.out.println(String.format("Model deployment state: %s", model.getDeploymentState())); + } + } + // [END automl_vision_list_models] + + // [START automl_vision_get_model] + /** + * Demonstrates using the AutoML client to get model details. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model. + * @throws IOException on Input/Output errors. + */ + public static void getModel(String projectId, String computeRegion, String modelId) + throws IOException { + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + + // Get complete detail of the model. + Model model = client.getModel(modelFullId); + + // Display the model information. + System.out.println(String.format("Model name: %s", model.getName())); + System.out.println( + String.format( + "Model id: %s", model.getName().split("/")[model.getName().split("/").length - 1])); + System.out.println(String.format("Model display name: %s", model.getDisplayName())); + System.out.println("Image classification model metadata:"); + System.out.println( + "Tranning budget: " + model.getImageClassificationModelMetadata().getTrainBudget()); + System.out.println( + "Tranning cost:" + model.getImageClassificationModelMetadata().getTrainCost()); + System.out.println( + String.format( + "Stop reason: %s", model.getImageClassificationModelMetadata().getStopReason())); + System.out.println( + String.format( + "Base model id: %s", model.getImageClassificationModelMetadata().getBaseModelId())); + System.out.println("Model create time:"); + System.out.println(String.format("\tseconds: %s", model.getCreateTime().getSeconds())); + System.out.println(String.format("\tnanos: %s", model.getCreateTime().getNanos())); + System.out.println(String.format("Model deployment state: %s", model.getDeploymentState())); + } + // [END automl_vision_get_model] + + // [START automl_vision_list_model_evaluations] + /** + * Demonstrates using the AutoML client to list model evaluations. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model. + * @param filter the Filter expression. + * @throws IOException on Input/Output errors. + */ + public static void listModelEvaluations( + String projectId, String computeRegion, String modelId, String filter) throws IOException { + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + + // Create list model evaluations request + ListModelEvaluationsRequest modelEvaluationsrequest = + ListModelEvaluationsRequest.newBuilder() + .setParent(modelFullId.toString()) + .setFilter(filter) + .build(); + + System.out.println("List of model evaluations:"); + // List all the model evaluations in the model by applying filter. + for (ModelEvaluation element : + client.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + System.out.println(element); + } + } + // [END automl_vision_list_model_evaluations] + + // [START automl_vision_get_model_evaluation] + /** + * Demonstrates using the AutoML client to get model evaluations. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model. + * @param modelEvaluationId the Id of your model evaluation. + * @throws IOException on Input/Output errors. + */ + public static void getModelEvaluation( + String projectId, String computeRegion, String modelId, String modelEvaluationId) + throws IOException { + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the model evaluation. + ModelEvaluationName modelEvaluationFullId = + ModelEvaluationName.of(projectId, computeRegion, modelId, modelEvaluationId); + // Perform the AutoML Model request to get Model Evaluation information + ModelEvaluation response = client.getModelEvaluation(modelEvaluationFullId); + + System.out.println(response); + } + // [END automl_vision_get_model_evaluation] + + // [START automl_vision_display_evaluation] + /** + * Demonstrates using the AutoML client to display model evaluation. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model. + * @param filter the filter expression. + * @throws IOException on Input/Output errors. + */ + public static void displayEvaluation( + String projectId, String computeRegion, String modelId, String filter) throws IOException { + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + + // List all the model evaluations in the model by applying filter. + ListModelEvaluationsRequest modelEvaluationsrequest = + ListModelEvaluationsRequest.newBuilder() + .setParent(modelFullId.toString()) + .setFilter(filter) + .build(); + + // Iterate through the results. + String modelEvaluationId = ""; + for (ModelEvaluation element : + client.listModelEvaluations(modelEvaluationsrequest).iterateAll()) { + if (element.getAnnotationSpecId() != null) { + modelEvaluationId = element.getName().split("/")[element.getName().split("/").length - 1]; + } + } + + // Resource name for the model evaluation. + ModelEvaluationName modelEvaluationFullId = + ModelEvaluationName.of(projectId, computeRegion, modelId, modelEvaluationId); + + // Get a model evaluation. + ModelEvaluation modelEvaluation = client.getModelEvaluation(modelEvaluationFullId); + + ClassificationEvaluationMetrics classMetrics = + modelEvaluation.getClassificationEvaluationMetrics(); + List confidenceMetricsEntries = + classMetrics.getConfidenceMetricsEntryList(); + + // Showing model score based on threshold of 0.5 + for (ConfidenceMetricsEntry confidenceMetricsEntry : confidenceMetricsEntries) { + if (confidenceMetricsEntry.getConfidenceThreshold() == 0.5) { + System.out.println("Precision and recall are based on a score threshold of 0.5"); + System.out.println( + String.format("Model Precision: %.2f ", confidenceMetricsEntry.getPrecision() * 100) + + '%'); + System.out.println( + String.format("Model Recall: %.2f ", confidenceMetricsEntry.getRecall() * 100) + '%'); + System.out.println( + String.format("Model F1 score: %.2f ", confidenceMetricsEntry.getF1Score() * 100) + + '%'); + System.out.println( + String.format( + "Model Precision@1: %.2f ", confidenceMetricsEntry.getPrecisionAt1() * 100) + + '%'); + System.out.println( + String.format("Model Recall@1: %.2f ", confidenceMetricsEntry.getRecallAt1() * 100) + + '%'); + System.out.println( + String.format("Model F1 score@1: %.2f ", confidenceMetricsEntry.getF1ScoreAt1() * 100) + + '%'); + } + } + } + // [END automl_vision_display_evaluation] + + // [START automl_vision_delete_model] + /** + * Demonstrates using the AutoML client to delete a model. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model. + * @throws Exception on AutoML Client errors + */ + public static void deleteModel(String projectId, String computeRegion, String modelId) + throws Exception { + AutoMlClient client = AutoMlClient.create(); + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId); + + // Delete a model. + Empty response = client.deleteModelAsync(modelFullId).get(); + + System.out.println("Model deletion started..."); + } + // [END automl_vision_delete_model] + + public static void main(String[] args) throws Exception { + ModelApi modelApi = new ModelApi(); + modelApi.argsHelper(args, System.out); + } + + public static void argsHelper(String[] args, PrintStream out) throws Exception { + ArgumentParser parser = + ArgumentParsers.newFor("ModelApi") + .build() + .defaultHelp(true) + .description("Model API operations."); + Subparsers subparsers = parser.addSubparsers().dest("command"); + + Subparser createModelParser = subparsers.addParser("create_model"); + createModelParser.addArgument("datasetId"); + createModelParser.addArgument("modelName"); + createModelParser.addArgument("trainBudget"); + + Subparser listModelParser = subparsers.addParser("list_models"); + listModelParser + .addArgument("filter") + .nargs("?") + .setDefault("imageClassificationModelMetadata:*"); + + Subparser getModelParser = subparsers.addParser("get_model"); + getModelParser.addArgument("modelId"); + + Subparser listModelEvaluationsParser = subparsers.addParser("list_model_evaluations"); + listModelEvaluationsParser.addArgument("modelId"); + listModelEvaluationsParser.addArgument("filter").nargs("?").setDefault(""); + + Subparser getModelEvaluationParser = subparsers.addParser("get_model_evaluation"); + getModelEvaluationParser.addArgument("modelId"); + getModelEvaluationParser.addArgument("modelEvaluationId"); + + Subparser displayEvaluationParser = subparsers.addParser("display_evaluation"); + displayEvaluationParser.addArgument("modelId"); + displayEvaluationParser.addArgument("filter").nargs("?").setDefault(""); + + Subparser deleteModelParser = subparsers.addParser("delete_model"); + deleteModelParser.addArgument("modelId"); + + Subparser getOperationStatusParser = subparsers.addParser("get_operation_status"); + getOperationStatusParser.addArgument("operationFullId"); + + String projectId = System.getenv("PROJECT_ID"); + String computeRegion = System.getenv("REGION_NAME"); + + Namespace ns = null; + try { + ns = parser.parseArgs(args); + if (ns.get("command").equals("create_model")) { + createModel( + projectId, + computeRegion, + ns.getString("datasetId"), + ns.getString("modelName"), + ns.getString("trainBudget")); + } + if (ns.get("command").equals("list_models")) { + listModels(projectId, computeRegion, ns.getString("filter")); + } + if (ns.get("command").equals("get_model")) { + getModel(projectId, computeRegion, ns.getString("modelId")); + } + if (ns.get("command").equals("list_model_evaluations")) { + listModelEvaluations( + projectId, computeRegion, ns.getString("modelId"), ns.getString("filter")); + } + if (ns.get("command").equals("get_model_evaluation")) { + getModelEvaluation( + projectId, computeRegion, ns.getString("modelId"), ns.getString("modelEvaluationId")); + } + if (ns.get("command").equals("display_evaluation")) { + displayEvaluation( + projectId, computeRegion, ns.getString("modelId"), ns.getString("filter")); + } + if (ns.get("command").equals("delete_model")) { + deleteModel(projectId, computeRegion, ns.getString("modelId")); + } + if (ns.get("command").equals("get_operation_status")) { + getOperationStatus(ns.getString("operationFullId")); + } + } catch (ArgumentParserException e) { + parser.handleError(e); + } + } +} diff --git a/vision/automl/src/main/java/com/google/cloud/vision/samples/automl/PredictionApi.java b/vision/automl/src/main/java/com/google/cloud/vision/samples/automl/PredictionApi.java new file mode 100644 index 00000000000..a76c3f40caf --- /dev/null +++ b/vision/automl/src/main/java/com/google/cloud/vision/samples/automl/PredictionApi.java @@ -0,0 +1,142 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This application demonstrates how to perform basic operations on prediction + * with the Google AutoML Vision API. + * + * For more information, the documentation at + * https://cloud.google.com/vision/automl/docs. + */ + +package com.google.cloud.vision.samples.automl; + +// Imports the Google Cloud client library +import com.google.cloud.automl.v1beta1.AnnotationPayload; +import com.google.cloud.automl.v1beta1.ExamplePayload; +import com.google.cloud.automl.v1beta1.Image; +import com.google.cloud.automl.v1beta1.ModelName; +import com.google.cloud.automl.v1beta1.PredictResponse; +import com.google.cloud.automl.v1beta1.PredictionServiceClient; +import com.google.protobuf.ByteString; + +import java.io.IOException; +import java.io.PrintStream; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import net.sourceforge.argparse4j.inf.Subparsers; + +/** + * Google Cloud AutoML Vision API sample application. Example usage: mvn package exec:java + * -Dexec.mainClass ='com.google.cloud.vision.samples.automl.PredictionApi' -Dexec.args='predict + * [modelId] [path-to-image] [scoreThreshold]' + */ +public class PredictionApi { + + // [START automl_vision_predict] + + /** + * Demonstrates using the AutoML client to predict an image. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model which will be used for text classification. + * @param filePath the Local text file path of the content to be classified. + * @param scoreThreshold the Confidence score. Only classifications with confidence score above + * scoreThreshold are displayed. + * @throws IOException on Input/Output errors. + */ + public static void predict( + String projectId, + String computeRegion, + String modelId, + String filePath, + String scoreThreshold) + throws IOException { + + // Instantiate client for prediction service. + PredictionServiceClient predictionClient = PredictionServiceClient.create(); + + // Get the full path of the model. + ModelName name = ModelName.of(projectId, computeRegion, modelId); + + // Read the image and assign to payload. + ByteString content = ByteString.copyFrom(Files.readAllBytes(Paths.get(filePath))); + Image image = Image.newBuilder().setImageBytes(content).build(); + ExamplePayload examplePayload = ExamplePayload.newBuilder().setImage(image).build(); + + // Additional parameters that can be provided for prediction e.g. Score Threshold + Map params = new HashMap<>(); + if (scoreThreshold != null) { + params.put("scoreThreshold", scoreThreshold); + } + // Perform the AutoML Prediction request + PredictResponse response = predictionClient.predict(name, examplePayload, params); + + System.out.println("Prediction results:"); + for (AnnotationPayload annotationPayload : response.getPayloadList()) { + System.out.println("Predicted class name :" + annotationPayload.getDisplayName()); + System.out.println( + "Predicted class score :" + annotationPayload.getClassification().getScore()); + } + } + // [END automl_vision_predict] + + public static void main(String[] args) throws IOException { + PredictionApi predictionApi = new PredictionApi(); + predictionApi.argsHelper(args, System.out); + } + + public static void argsHelper(String[] args, PrintStream out) throws IOException { + ArgumentParser parser = + ArgumentParsers.newFor("PredictionApi") + .build() + .defaultHelp(true) + .description("Prediction API Operation"); + Subparsers subparsers = parser.addSubparsers().dest("command"); + + Subparser predictParser = subparsers.addParser("predict"); + predictParser.addArgument("modelId"); + predictParser.addArgument("filePath"); + predictParser.addArgument("scoreThreshold").nargs("?").type(String.class).setDefault(""); + + String projectId = System.getenv("PROJECT_ID"); + String computeRegion = System.getenv("REGION_NAME"); + + Namespace ns = null; + try { + ns = parser.parseArgs(args); + if (ns.get("command").equals("predict")) { + predict( + projectId, + computeRegion, + ns.getString("modelId"), + ns.getString("filePath"), + ns.getString("scoreThreshold")); + } + } catch (ArgumentParserException e) { + parser.handleError(e); + } + } +} diff --git a/vision/automl/src/test/java/com/google/cloud/vision/samples/automl/DatasetApiIT.java b/vision/automl/src/test/java/com/google/cloud/vision/samples/automl/DatasetApiIT.java new file mode 100644 index 00000000000..41cf9dec9a5 --- /dev/null +++ b/vision/automl/src/test/java/com/google/cloud/vision/samples/automl/DatasetApiIT.java @@ -0,0 +1,107 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vision.samples.automl; + +import static com.google.common.truth.Truth.assertThat; +import static java.lang.Boolean.FALSE; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for Automl vision "Dataset API" sample. */ +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class DatasetApiIT { + + private static final String PROJECT_ID = "java-docs-samples-testing"; + private static final String BUCKET = PROJECT_ID + "-vcm"; + private static final String COMPUTE_REGION = "us-central1"; + private static final String DATASET_NAME = "test_vision_dataset"; + private ByteArrayOutputStream bout; + private PrintStream out; + private DatasetApi app; + private String datasetId; + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() { + System.setOut(null); + } + + @Test + public void testCreateImportDeleteDataset() throws Exception { + // Act + DatasetApi.createDataset(PROJECT_ID, COMPUTE_REGION, DATASET_NAME, FALSE); + + // Assert + String got = bout.toString(); + datasetId = + bout.toString() + .split("\n")[0] + .split("/")[(bout.toString().split("\n")[0]).split("/").length - 1]; + assertThat(got).contains("Dataset id:"); + + // Act + DatasetApi.importData( + PROJECT_ID, COMPUTE_REGION, datasetId, "gs://" + BUCKET + "/flower_traindata.csv"); + + // Assert + got = bout.toString(); + assertThat(got).contains("Dataset id:"); + + // Act + DatasetApi.deleteDataset(PROJECT_ID, COMPUTE_REGION, datasetId); + + // Assert + got = bout.toString(); + assertThat(got).contains("Dataset deleted."); + } + + @Test + public void testListGetDatasets() throws Exception { + // Act + DatasetApi.listDatasets(PROJECT_ID, COMPUTE_REGION, "imageClassificationDatasetMetadata:*"); + + // Assert + String got = bout.toString(); + datasetId = + bout.toString() + .split("\n")[1] + .split("/")[(bout.toString().split("\n")[1]).split("/").length - 1]; + assertThat(got).contains("Dataset id:"); + + // Act + DatasetApi.getDataset(PROJECT_ID, COMPUTE_REGION, datasetId); + + // Assert + got = bout.toString(); + + assertThat(got).contains("Dataset id:"); + } +} diff --git a/vision/automl/src/test/java/com/google/cloud/vision/samples/automl/ModelApiIT.java b/vision/automl/src/test/java/com/google/cloud/vision/samples/automl/ModelApiIT.java new file mode 100644 index 00000000000..9dae82604a3 --- /dev/null +++ b/vision/automl/src/test/java/com/google/cloud/vision/samples/automl/ModelApiIT.java @@ -0,0 +1,88 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vision.samples.automl; + +import static com.google.common.truth.Truth.assertThat; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for vision "Model API" sample. */ +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class ModelApiIT { + private static final String PROJECT_ID = "java-docs-samples-testing"; + private static final String COMPUTE_REGION = "us-central1"; + private ByteArrayOutputStream bout; + private PrintStream out; + private ModelApi app; + private String modelId; + private String modelEvaluationId; + + @Before + public void setUp() { + + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() { + System.setOut(null); + } + + @Test + public void testModelApi() throws Exception { + // Act + ModelApi.listModels(PROJECT_ID, COMPUTE_REGION, ""); + + // Assert + String got = bout.toString(); + modelId = got.split("\n")[1].split("/")[got.split("\n")[1].split("/").length - 1]; + assertThat(got).contains("Model id:"); + + // Act + ModelApi.getModel(PROJECT_ID, COMPUTE_REGION, modelId); + + // Assert + got = bout.toString(); + assertThat(got).contains("Model name:"); + + // Act + ModelApi.listModelEvaluations(PROJECT_ID, COMPUTE_REGION, modelId, ""); + + // Assert + got = bout.toString(); + modelEvaluationId = got.split("List of model evaluations:")[1].split("\"")[1].split("/")[7]; + assertThat(got).contains("name:"); + + // Act + ModelApi.getModelEvaluation(PROJECT_ID, COMPUTE_REGION, modelId, modelEvaluationId); + + // Assert + got = bout.toString(); + assertThat(got).contains("name:"); + + } +} diff --git a/vision/automl/src/test/java/com/google/cloud/vision/samples/automl/PredictionApiIT.java b/vision/automl/src/test/java/com/google/cloud/vision/samples/automl/PredictionApiIT.java new file mode 100644 index 00000000000..db0e0401a01 --- /dev/null +++ b/vision/automl/src/test/java/com/google/cloud/vision/samples/automl/PredictionApiIT.java @@ -0,0 +1,64 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vision.samples.automl; + +import static com.google.common.truth.Truth.assertThat; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for vision "PredictionAPI" sample. */ +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class PredictionApiIT { + private static final String COMPUTE_REGION = "us-central1"; + private static final String PROJECT_ID = "java-docs-samples-testing"; + private static final String modelId = "620201829169141520"; + private static final String filePath = "./resources/dandelion.jpg"; + private static final String scoreThreshold = "0.7"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PredictionApi app; + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() { + System.setOut(null); + } + + @Test + public void testPredict() throws Exception { + // Act + PredictionApi.predict(PROJECT_ID, COMPUTE_REGION, modelId, filePath, scoreThreshold); + + // Assert + String got = bout.toString(); + assertThat(got).contains("dandelion"); + } +}